1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements mlir::applyPatternsAndFoldGreedily.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 #include "mlir/Rewrite/PatternApplicator.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/RegionUtils.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/ScopedPrinter.h"
23 #include "llvm/Support/raw_ostream.h"
24
25 using namespace mlir;
26
27 #define DEBUG_TYPE "greedy-rewriter"
28
29 //===----------------------------------------------------------------------===//
30 // GreedyPatternRewriteDriver
31 //===----------------------------------------------------------------------===//
32
33 namespace {
34 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
35 /// applies the locally optimal patterns in a roughly "bottom up" way.
36 class GreedyPatternRewriteDriver : public PatternRewriter {
37 public:
38 explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
39 const FrozenRewritePatternSet &patterns,
40 const GreedyRewriteConfig &config);
41
42 /// Simplify the operations within the given regions.
43 bool simplify(MutableArrayRef<Region> regions);
44
45 /// Add the given operation to the worklist.
46 virtual void addToWorklist(Operation *op);
47
48 /// Pop the next operation from the worklist.
49 Operation *popFromWorklist();
50
51 /// If the specified operation is in the worklist, remove it.
52 void removeFromWorklist(Operation *op);
53
54 protected:
55 // Implement the hook for inserting operations, and make sure that newly
56 // inserted ops are added to the worklist for processing.
57 void notifyOperationInserted(Operation *op) override;
58
59 // Look over the provided operands for any defining operations that should
60 // be re-added to the worklist. This function should be called when an
61 // operation is modified or removed, as it may trigger further
62 // simplifications.
63 void addOperandsToWorklist(ValueRange operands);
64
65 // If an operation is about to be removed, make sure it is not in our
66 // worklist anymore because we'd get dangling references to it.
67 void notifyOperationRemoved(Operation *op) override;
68
69 // When the root of a pattern is about to be replaced, it can trigger
70 // simplifications to its users - make sure to add them to the worklist
71 // before the root is changed.
72 void notifyRootReplaced(Operation *op) override;
73
74 /// PatternRewriter hook for erasing a dead operation.
75 void eraseOp(Operation *op) override;
76
77 /// PatternRewriter hook for notifying match failure reasons.
78 LogicalResult
79 notifyMatchFailure(Location loc,
80 function_ref<void(Diagnostic &)> reasonCallback) override;
81
82 /// The low-level pattern applicator.
83 PatternApplicator matcher;
84
85 /// The worklist for this transformation keeps track of the operations that
86 /// need to be revisited, plus their index in the worklist. This allows us to
87 /// efficiently remove operations from the worklist when they are erased, even
88 /// if they aren't the root of a pattern.
89 std::vector<Operation *> worklist;
90 DenseMap<Operation *, unsigned> worklistMap;
91
92 /// Non-pattern based folder for operations.
93 OperationFolder folder;
94
95 private:
96 /// Configuration information for how to simplify.
97 GreedyRewriteConfig config;
98
99 #ifndef NDEBUG
100 /// A logger used to emit information during the application process.
101 llvm::ScopedPrinter logger{llvm::dbgs()};
102 #endif
103 };
104 } // namespace
105
GreedyPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,const GreedyRewriteConfig & config)106 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
107 MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
108 const GreedyRewriteConfig &config)
109 : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
110 worklist.reserve(64);
111
112 // Apply a simple cost model based solely on pattern benefit.
113 matcher.applyDefaultCostModel();
114 }
115
simplify(MutableArrayRef<Region> regions)116 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
117 #ifndef NDEBUG
118 const char *logLineComment =
119 "//===-------------------------------------------===//\n";
120
121 /// A utility function to log a process result for the given reason.
122 auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
123 logger.unindent();
124 logger.startLine() << "} -> " << result;
125 if (!msg.isTriviallyEmpty())
126 logger.getOStream() << " : " << msg;
127 logger.getOStream() << "\n";
128 };
129 auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
130 logResult(result, msg);
131 logger.startLine() << logLineComment;
132 };
133 #endif
134
135 auto insertKnownConstant = [&](Operation *op) {
136 // Check for existing constants when populating the worklist. This avoids
137 // accidentally reversing the constant order during processing.
138 Attribute constValue;
139 if (matchPattern(op, m_Constant(&constValue)))
140 if (!folder.insertKnownConstant(op, constValue))
141 return true;
142 return false;
143 };
144
145 bool changed = false;
146 unsigned iteration = 0;
147 do {
148 worklist.clear();
149 worklistMap.clear();
150
151 if (!config.useTopDownTraversal) {
152 // Add operations to the worklist in postorder.
153 for (auto ®ion : regions) {
154 region.walk([&](Operation *op) {
155 if (!insertKnownConstant(op))
156 addToWorklist(op);
157 });
158 }
159 } else {
160 // Add all nested operations to the worklist in preorder.
161 for (auto ®ion : regions) {
162 region.walk<WalkOrder::PreOrder>([&](Operation *op) {
163 if (!insertKnownConstant(op)) {
164 worklist.push_back(op);
165 return WalkResult::advance();
166 }
167 return WalkResult::skip();
168 });
169 }
170
171 // Reverse the list so our pop-back loop processes them in-order.
172 std::reverse(worklist.begin(), worklist.end());
173 // Remember the reverse index.
174 for (size_t i = 0, e = worklist.size(); i != e; ++i)
175 worklistMap[worklist[i]] = i;
176 }
177
178 // These are scratch vectors used in the folding loop below.
179 SmallVector<Value, 8> originalOperands, resultValues;
180
181 changed = false;
182 while (!worklist.empty()) {
183 auto *op = popFromWorklist();
184
185 // Nulls get added to the worklist when operations are removed, ignore
186 // them.
187 if (op == nullptr)
188 continue;
189
190 LLVM_DEBUG({
191 logger.getOStream() << "\n";
192 logger.startLine() << logLineComment;
193 logger.startLine() << "Processing operation : '" << op->getName()
194 << "'(" << op << ") {\n";
195 logger.indent();
196
197 // If the operation has no regions, just print it here.
198 if (op->getNumRegions() == 0) {
199 op->print(
200 logger.startLine(),
201 OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
202 logger.getOStream() << "\n\n";
203 }
204 });
205
206 // If the operation is trivially dead - remove it.
207 if (isOpTriviallyDead(op)) {
208 notifyOperationRemoved(op);
209 op->erase();
210 changed = true;
211
212 LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
213 continue;
214 }
215
216 // Collects all the operands and result uses of the given `op` into work
217 // list. Also remove `op` and nested ops from worklist.
218 originalOperands.assign(op->operand_begin(), op->operand_end());
219 auto preReplaceAction = [&](Operation *op) {
220 // Add the operands to the worklist for visitation.
221 addOperandsToWorklist(originalOperands);
222
223 // Add all the users of the result to the worklist so we make sure
224 // to revisit them.
225 for (auto result : op->getResults())
226 for (auto *userOp : result.getUsers())
227 addToWorklist(userOp);
228
229 notifyOperationRemoved(op);
230 };
231
232 // Add the given operation to the worklist.
233 auto collectOps = [this](Operation *op) { addToWorklist(op); };
234
235 // Try to fold this op.
236 bool inPlaceUpdate;
237 if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
238 &inPlaceUpdate)))) {
239 LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
240
241 changed = true;
242 if (!inPlaceUpdate)
243 continue;
244 }
245
246 // Try to match one of the patterns. The rewriter is automatically
247 // notified of any necessary changes, so there is nothing else to do
248 // here.
249 #ifndef NDEBUG
250 auto canApply = [&](const Pattern &pattern) {
251 LLVM_DEBUG({
252 logger.getOStream() << "\n";
253 logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
254 << op->getName() << " -> (";
255 llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
256 logger.getOStream() << ")' {\n";
257 logger.indent();
258 });
259 return true;
260 };
261 auto onFailure = [&](const Pattern &pattern) {
262 LLVM_DEBUG(logResult("failure", "pattern failed to match"));
263 };
264 auto onSuccess = [&](const Pattern &pattern) {
265 LLVM_DEBUG(logResult("success", "pattern applied successfully"));
266 return success();
267 };
268
269 LogicalResult matchResult =
270 matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
271 if (succeeded(matchResult))
272 LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
273 else
274 LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
275 #else
276 LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
277 #endif
278 changed |= succeeded(matchResult);
279 }
280
281 // After applying patterns, make sure that the CFG of each of the regions
282 // is kept up to date.
283 if (config.enableRegionSimplification)
284 changed |= succeeded(simplifyRegions(*this, regions));
285 } while (changed &&
286 (iteration++ < config.maxIterations ||
287 config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
288
289 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
290 return !changed;
291 }
292
addToWorklist(Operation * op)293 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
294 // Check to see if the worklist already contains this op.
295 if (worklistMap.count(op))
296 return;
297
298 worklistMap[op] = worklist.size();
299 worklist.push_back(op);
300 }
301
popFromWorklist()302 Operation *GreedyPatternRewriteDriver::popFromWorklist() {
303 auto *op = worklist.back();
304 worklist.pop_back();
305
306 // This operation is no longer in the worklist, keep worklistMap up to date.
307 if (op)
308 worklistMap.erase(op);
309 return op;
310 }
311
removeFromWorklist(Operation * op)312 void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
313 auto it = worklistMap.find(op);
314 if (it != worklistMap.end()) {
315 assert(worklist[it->second] == op && "malformed worklist data structure");
316 worklist[it->second] = nullptr;
317 worklistMap.erase(it);
318 }
319 }
320
notifyOperationInserted(Operation * op)321 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
322 LLVM_DEBUG({
323 logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
324 << ")\n";
325 });
326 addToWorklist(op);
327 }
328
addOperandsToWorklist(ValueRange operands)329 void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
330 for (Value operand : operands) {
331 // If the use count of this operand is now < 2, we re-add the defining
332 // operation to the worklist.
333 // TODO: This is based on the fact that zero use operations
334 // may be deleted, and that single use values often have more
335 // canonicalization opportunities.
336 if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
337 continue;
338 if (auto *defOp = operand.getDefiningOp())
339 addToWorklist(defOp);
340 }
341 }
342
notifyOperationRemoved(Operation * op)343 void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
344 addOperandsToWorklist(op->getOperands());
345 op->walk([this](Operation *operation) {
346 removeFromWorklist(operation);
347 folder.notifyRemoval(operation);
348 });
349 }
350
notifyRootReplaced(Operation * op)351 void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
352 LLVM_DEBUG({
353 logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
354 << ")\n";
355 });
356 for (auto result : op->getResults())
357 for (auto *user : result.getUsers())
358 addToWorklist(user);
359 }
360
eraseOp(Operation * op)361 void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
362 LLVM_DEBUG({
363 logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
364 << ")\n";
365 });
366 PatternRewriter::eraseOp(op);
367 }
368
notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)369 LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
370 Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
371 LLVM_DEBUG({
372 Diagnostic diag(loc, DiagnosticSeverity::Remark);
373 reasonCallback(diag);
374 logger.startLine() << "** Failure : " << diag.str() << "\n";
375 });
376 return failure();
377 }
378
379 /// Rewrite the regions of the specified operation, which must be isolated from
380 /// above, by repeatedly applying the highest benefit patterns in a greedy
381 /// work-list driven manner. Return success if no more patterns can be matched
382 /// in the result operation regions. Note: This does not apply patterns to the
383 /// top-level operation itself.
384 ///
385 LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,const FrozenRewritePatternSet & patterns,GreedyRewriteConfig config)386 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
387 const FrozenRewritePatternSet &patterns,
388 GreedyRewriteConfig config) {
389 if (regions.empty())
390 return success();
391
392 // The top-level operation must be known to be isolated from above to
393 // prevent performing canonicalizations on operations defined at or above
394 // the region containing 'op'.
395 auto regionIsIsolated = [](Region ®ion) {
396 return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
397 };
398 (void)regionIsIsolated;
399 assert(llvm::all_of(regions, regionIsIsolated) &&
400 "patterns can only be applied to operations IsolatedFromAbove");
401
402 // Start the pattern driver.
403 GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
404 bool converged = driver.simplify(regions);
405 LLVM_DEBUG(if (!converged) {
406 llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
407 << config.maxIterations << " times\n";
408 });
409 return success(converged);
410 }
411
412 //===----------------------------------------------------------------------===//
413 // OpPatternRewriteDriver
414 //===----------------------------------------------------------------------===//
415
416 namespace {
417 /// This is a simple driver for the PatternMatcher to apply patterns and perform
418 /// folding on a single op. It repeatedly applies locally optimal patterns.
419 class OpPatternRewriteDriver : public PatternRewriter {
420 public:
OpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns)421 explicit OpPatternRewriteDriver(MLIRContext *ctx,
422 const FrozenRewritePatternSet &patterns)
423 : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
424 // Apply a simple cost model based solely on pattern benefit.
425 matcher.applyDefaultCostModel();
426 }
427
428 LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
429
430 // These are hooks implemented for PatternRewriter.
431 protected:
432 /// If an operation is about to be removed, mark it so that we can let clients
433 /// know.
notifyOperationRemoved(Operation * op)434 void notifyOperationRemoved(Operation *op) override {
435 opErasedViaPatternRewrites = true;
436 }
437
438 // When a root is going to be replaced, its removal will be notified as well.
439 // So there is nothing to do here.
notifyRootReplaced(Operation * op)440 void notifyRootReplaced(Operation *op) override {}
441
442 private:
443 /// The low-level pattern applicator.
444 PatternApplicator matcher;
445
446 /// Non-pattern based folder for operations.
447 OperationFolder folder;
448
449 /// Set to true if the operation has been erased via pattern rewrites.
450 bool opErasedViaPatternRewrites = false;
451 };
452
453 } // namespace
454
455 /// Performs the rewrites and folding only on `op`. The simplification
456 /// converges if the op is erased as a result of being folded, replaced, or
457 /// becoming dead, or no more changes happen in an iteration. Returns success if
458 /// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
459 /// gets erased.
simplifyLocally(Operation * op,int maxIterations,bool & erased)460 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
461 int maxIterations,
462 bool &erased) {
463 bool changed = false;
464 erased = false;
465 opErasedViaPatternRewrites = false;
466 int iterations = 0;
467 // Iterate until convergence or until maxIterations. Deletion of the op as
468 // a result of being dead or folded is convergence.
469 do {
470 changed = false;
471
472 // If the operation is trivially dead - remove it.
473 if (isOpTriviallyDead(op)) {
474 op->erase();
475 erased = true;
476 return success();
477 }
478
479 // Try to fold this op.
480 bool inPlaceUpdate;
481 if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
482 /*preReplaceAction=*/nullptr,
483 &inPlaceUpdate))) {
484 changed = true;
485 if (!inPlaceUpdate) {
486 erased = true;
487 return success();
488 }
489 }
490
491 // Try to match one of the patterns. The rewriter is automatically
492 // notified of any necessary changes, so there is nothing else to do here.
493 changed |= succeeded(matcher.matchAndRewrite(op, *this));
494 if ((erased = opErasedViaPatternRewrites))
495 return success();
496 } while (changed &&
497 (++iterations < maxIterations ||
498 maxIterations == GreedyRewriteConfig::kNoIterationLimit));
499
500 // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
501 return failure(changed);
502 }
503
504 //===----------------------------------------------------------------------===//
505 // MultiOpPatternRewriteDriver
506 //===----------------------------------------------------------------------===//
507
508 namespace {
509
510 /// This is a specialized GreedyPatternRewriteDriver to apply patterns and
511 /// perform folding for a supplied set of ops. It repeatedly simplifies while
512 /// restricting the rewrites to only the provided set of ops or optionally
513 /// to those directly affected by it (result users or operand providers).
514 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
515 public:
MultiOpPatternRewriteDriver(MLIRContext * ctx,const FrozenRewritePatternSet & patterns,bool strict)516 explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
517 const FrozenRewritePatternSet &patterns,
518 bool strict)
519 : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
520 strictMode(strict) {}
521
522 bool simplifyLocally(ArrayRef<Operation *> op);
523
addToWorklist(Operation * op)524 void addToWorklist(Operation *op) override {
525 if (!strictMode || strictModeFilteredOps.contains(op))
526 GreedyPatternRewriteDriver::addToWorklist(op);
527 }
528
529 private:
notifyOperationInserted(Operation * op)530 void notifyOperationInserted(Operation *op) override {
531 GreedyPatternRewriteDriver::notifyOperationInserted(op);
532 if (strictMode)
533 strictModeFilteredOps.insert(op);
534 }
535
notifyOperationRemoved(Operation * op)536 void notifyOperationRemoved(Operation *op) override {
537 GreedyPatternRewriteDriver::notifyOperationRemoved(op);
538 if (strictMode)
539 strictModeFilteredOps.erase(op);
540 }
541
542 /// If `strictMode` is true, any pre-existing ops outside of
543 /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
544 /// If `strictMode` is false, operations that use results of (or supply
545 /// operands to) any rewritten ops stemming from the simplification of the
546 /// provided ops are in turn simplified; any other ops still remain untouched
547 /// (i.e., regardless of `strictMode`).
548 bool strictMode = false;
549
550 /// The list of ops we are restricting our rewrites to if `strictMode` is on.
551 /// These include the supplied set of ops as well as new ops created while
552 /// rewriting those ops. This set is not maintained when strictMode is off.
553 llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
554 };
555
556 } // namespace
557
558 /// Performs the specified rewrites on `ops` while also trying to fold these ops
559 /// as well as any other ops that were in turn created due to these rewrite
560 /// patterns. Any pre-existing ops outside of `ops` remain completely
561 /// unmodified if `strictMode` is true. If `strictMode` is false, other
562 /// operations that use results of rewritten ops or supply operands to such ops
563 /// are in turn simplified; any other ops still remain unmodified (i.e.,
564 /// regardless of `strictMode`). Note that ops in `ops` could be erased as a
565 /// result of folding, becoming dead, or via pattern rewrites. Returns true if
566 /// at all any changes happened.
567 // Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
568 // or GreedyPatternRewriteDriver::simplify, this method just iterates until
569 // the worklist is empty. As our objective is to keep simplification "local",
570 // there is no strong rationale to re-add all operations into the worklist and
571 // rerun until an iteration changes nothing. If more widereaching simplification
572 // is desired, GreedyPatternRewriteDriver should be used.
simplifyLocally(ArrayRef<Operation * > ops)573 bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
574 if (strictMode) {
575 strictModeFilteredOps.clear();
576 strictModeFilteredOps.insert(ops.begin(), ops.end());
577 }
578
579 bool changed = false;
580 worklist.clear();
581 worklistMap.clear();
582 for (Operation *op : ops)
583 addToWorklist(op);
584
585 // These are scratch vectors used in the folding loop below.
586 SmallVector<Value, 8> originalOperands, resultValues;
587 while (!worklist.empty()) {
588 Operation *op = popFromWorklist();
589
590 // Nulls get added to the worklist when operations are removed, ignore
591 // them.
592 if (op == nullptr)
593 continue;
594
595 assert((!strictMode || strictModeFilteredOps.contains(op)) &&
596 "unexpected op was inserted under strict mode");
597
598 // If the operation is trivially dead - remove it.
599 if (isOpTriviallyDead(op)) {
600 notifyOperationRemoved(op);
601 op->erase();
602 changed = true;
603 continue;
604 }
605
606 // Collects all the operands and result uses of the given `op` into work
607 // list. Also remove `op` and nested ops from worklist.
608 originalOperands.assign(op->operand_begin(), op->operand_end());
609 auto preReplaceAction = [&](Operation *op) {
610 // Add the operands to the worklist for visitation.
611 addOperandsToWorklist(originalOperands);
612
613 // Add all the users of the result to the worklist so we make sure
614 // to revisit them.
615 for (Value result : op->getResults()) {
616 for (Operation *userOp : result.getUsers())
617 addToWorklist(userOp);
618 }
619
620 notifyOperationRemoved(op);
621 };
622
623 // Add the given operation generated by the folder to the worklist.
624 auto processGeneratedConstants = [this](Operation *op) {
625 notifyOperationInserted(op);
626 };
627
628 // Try to fold this op.
629 bool inPlaceUpdate;
630 if (succeeded(folder.tryToFold(op, processGeneratedConstants,
631 preReplaceAction, &inPlaceUpdate))) {
632 changed = true;
633 if (!inPlaceUpdate) {
634 // Op has been erased.
635 continue;
636 }
637 }
638
639 // Try to match one of the patterns. The rewriter is automatically
640 // notified of any necessary changes, so there is nothing else to do
641 // here.
642 changed |= succeeded(matcher.matchAndRewrite(op, *this));
643 }
644
645 return changed;
646 }
647
648 /// Rewrites only `op` using the supplied canonicalization patterns and
649 /// folding. `erased` is set to true if the op is erased as a result of being
650 /// folded, replaced, or dead.
applyOpPatternsAndFold(Operation * op,const FrozenRewritePatternSet & patterns,bool * erased)651 LogicalResult mlir::applyOpPatternsAndFold(
652 Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
653 // Start the pattern driver.
654 GreedyRewriteConfig config;
655 OpPatternRewriteDriver driver(op->getContext(), patterns);
656 bool opErased;
657 LogicalResult converged =
658 driver.simplifyLocally(op, config.maxIterations, opErased);
659 if (erased)
660 *erased = opErased;
661 LLVM_DEBUG(if (failed(converged)) {
662 llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
663 << config.maxIterations << " times";
664 });
665 return converged;
666 }
667
applyOpPatternsAndFold(ArrayRef<Operation * > ops,const FrozenRewritePatternSet & patterns,bool strict)668 bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
669 const FrozenRewritePatternSet &patterns,
670 bool strict) {
671 if (ops.empty())
672 return false;
673
674 // Start the pattern driver.
675 MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
676 strict);
677 return driver.simplifyLocally(ops);
678 }
679