1 //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements mlir::applyPatternsAndFoldGreedily. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 14 #include "mlir/Interfaces/SideEffectInterfaces.h" 15 #include "mlir/Rewrite/PatternApplicator.h" 16 #include "mlir/Transforms/FoldUtils.h" 17 #include "mlir/Transforms/RegionUtils.h" 18 #include "llvm/ADT/DenseMap.h" 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 using namespace mlir; 24 25 #define DEBUG_TYPE "pattern-matcher" 26 27 /// The max number of iterations scanning for pattern match. 28 static unsigned maxPatternMatchIterations = 10; 29 30 //===----------------------------------------------------------------------===// 31 // GreedyPatternRewriteDriver 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly 36 /// applies the locally optimal patterns in a roughly "bottom up" way. 37 class GreedyPatternRewriteDriver : public PatternRewriter { 38 public: 39 explicit GreedyPatternRewriteDriver(MLIRContext *ctx, 40 const FrozenRewritePatternSet &patterns, 41 bool useTopDownTraversal) 42 : PatternRewriter(ctx), matcher(patterns), folder(ctx), 43 useTopDownTraversal(useTopDownTraversal) { 44 worklist.reserve(64); 45 46 // Apply a simple cost model based solely on pattern benefit. 47 matcher.applyDefaultCostModel(); 48 } 49 50 bool simplify(MutableArrayRef<Region> regions, int maxIterations); 51 52 void addToWorklist(Operation *op) { 53 // Check to see if the worklist already contains this op. 54 if (worklistMap.count(op)) 55 return; 56 57 worklistMap[op] = worklist.size(); 58 worklist.push_back(op); 59 } 60 61 Operation *popFromWorklist() { 62 auto *op = worklist.back(); 63 worklist.pop_back(); 64 65 // This operation is no longer in the worklist, keep worklistMap up to date. 66 if (op) 67 worklistMap.erase(op); 68 return op; 69 } 70 71 /// If the specified operation is in the worklist, remove it. If not, this is 72 /// a no-op. 73 void removeFromWorklist(Operation *op) { 74 auto it = worklistMap.find(op); 75 if (it != worklistMap.end()) { 76 assert(worklist[it->second] == op && "malformed worklist data structure"); 77 worklist[it->second] = nullptr; 78 worklistMap.erase(it); 79 } 80 } 81 82 // These are hooks implemented for PatternRewriter. 83 protected: 84 // Implement the hook for inserting operations, and make sure that newly 85 // inserted ops are added to the worklist for processing. 86 void notifyOperationInserted(Operation *op) override { addToWorklist(op); } 87 88 // If an operation is about to be removed, make sure it is not in our 89 // worklist anymore because we'd get dangling references to it. 90 void notifyOperationRemoved(Operation *op) override { 91 addToWorklist(op->getOperands()); 92 op->walk([this](Operation *operation) { 93 removeFromWorklist(operation); 94 folder.notifyRemoval(operation); 95 }); 96 } 97 98 // When the root of a pattern is about to be replaced, it can trigger 99 // simplifications to its users - make sure to add them to the worklist 100 // before the root is changed. 101 void notifyRootReplaced(Operation *op) override { 102 for (auto result : op->getResults()) 103 for (auto *user : result.getUsers()) 104 addToWorklist(user); 105 } 106 107 private: 108 // Look over the provided operands for any defining operations that should 109 // be re-added to the worklist. This function should be called when an 110 // operation is modified or removed, as it may trigger further 111 // simplifications. 112 template <typename Operands> 113 void addToWorklist(Operands &&operands) { 114 for (Value operand : operands) { 115 // If the use count of this operand is now < 2, we re-add the defining 116 // operation to the worklist. 117 // TODO: This is based on the fact that zero use operations 118 // may be deleted, and that single use values often have more 119 // canonicalization opportunities. 120 if (!operand || (!operand.use_empty() && !operand.hasOneUse())) 121 continue; 122 if (auto *defInst = operand.getDefiningOp()) 123 addToWorklist(defInst); 124 } 125 } 126 127 /// The low-level pattern applicator. 128 PatternApplicator matcher; 129 130 /// The worklist for this transformation keeps track of the operations that 131 /// need to be revisited, plus their index in the worklist. This allows us to 132 /// efficiently remove operations from the worklist when they are erased, even 133 /// if they aren't the root of a pattern. 134 std::vector<Operation *> worklist; 135 DenseMap<Operation *, unsigned> worklistMap; 136 137 /// Non-pattern based folder for operations. 138 OperationFolder folder; 139 140 // Whether to use top-down or bottom-up traversal order. 141 bool useTopDownTraversal; 142 }; 143 } // end anonymous namespace 144 145 /// Performs the rewrites while folding and erasing any dead ops. Returns true 146 /// if the rewrite converges in `maxIterations`. 147 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions, 148 int maxIterations) { 149 // Perform a prepass over the IR to discover constants. 150 for (auto ®ion : regions) 151 folder.processExistingConstants(region); 152 153 bool changed = false; 154 int iteration = 0; 155 do { 156 worklist.clear(); 157 worklistMap.clear(); 158 159 // Add all nested operations to the worklist in preorder. 160 for (auto ®ion : regions) 161 if (useTopDownTraversal) 162 region.walk<WalkOrder::PreOrder>( 163 [this](Operation *op) { worklist.push_back(op); }); 164 else 165 region.walk([this](Operation *op) { addToWorklist(op); }); 166 167 if (useTopDownTraversal) { 168 // Reverse the list so our pop-back loop processes them in-order. 169 std::reverse(worklist.begin(), worklist.end()); 170 // Remember the reverse index. 171 for (unsigned i = 0, e = worklist.size(); i != e; ++i) 172 worklistMap[worklist[i]] = i; 173 } 174 175 // These are scratch vectors used in the folding loop below. 176 SmallVector<Value, 8> originalOperands, resultValues; 177 178 changed = false; 179 while (!worklist.empty()) { 180 auto *op = popFromWorklist(); 181 182 // Nulls get added to the worklist when operations are removed, ignore 183 // them. 184 if (op == nullptr) 185 continue; 186 187 // If the operation is trivially dead - remove it. 188 if (isOpTriviallyDead(op)) { 189 notifyOperationRemoved(op); 190 op->erase(); 191 changed = true; 192 continue; 193 } 194 195 // Collects all the operands and result uses of the given `op` into work 196 // list. Also remove `op` and nested ops from worklist. 197 originalOperands.assign(op->operand_begin(), op->operand_end()); 198 auto preReplaceAction = [&](Operation *op) { 199 // Add the operands to the worklist for visitation. 200 addToWorklist(originalOperands); 201 202 // Add all the users of the result to the worklist so we make sure 203 // to revisit them. 204 for (auto result : op->getResults()) 205 for (auto *userOp : result.getUsers()) 206 addToWorklist(userOp); 207 208 notifyOperationRemoved(op); 209 }; 210 211 // Add the given operation to the worklist. 212 auto collectOps = [this](Operation *op) { addToWorklist(op); }; 213 214 // Try to fold this op. 215 bool inPlaceUpdate; 216 if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction, 217 &inPlaceUpdate)))) { 218 changed = true; 219 if (!inPlaceUpdate) 220 continue; 221 } 222 223 // Try to match one of the patterns. The rewriter is automatically 224 // notified of any necessary changes, so there is nothing else to do here. 225 changed |= succeeded(matcher.matchAndRewrite(op, *this)); 226 } 227 228 // After applying patterns, make sure that the CFG of each of the regions is 229 // kept up to date. 230 changed |= succeeded(simplifyRegions(*this, regions)); 231 } while (changed && ++iteration < maxIterations); 232 233 // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 234 return !changed; 235 } 236 237 /// Rewrite the regions of the specified operation, which must be isolated from 238 /// above, by repeatedly applying the highest benefit patterns in a greedy 239 /// work-list driven manner. Return success if no more patterns can be matched 240 /// in the result operation regions. Note: This does not apply patterns to the 241 /// top-level operation itself. 242 /// 243 LogicalResult 244 mlir::applyPatternsAndFoldGreedily(Operation *op, 245 const FrozenRewritePatternSet &patterns, 246 bool useTopDownTraversal) { 247 return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations, 248 useTopDownTraversal); 249 } 250 LogicalResult mlir::applyPatternsAndFoldGreedily( 251 Operation *op, const FrozenRewritePatternSet &patterns, 252 unsigned maxIterations, bool useTopDownTraversal) { 253 return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations, 254 useTopDownTraversal); 255 } 256 /// Rewrite the given regions, which must be isolated from above. 257 LogicalResult 258 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions, 259 const FrozenRewritePatternSet &patterns, 260 bool useTopDownTraversal) { 261 return applyPatternsAndFoldGreedily( 262 regions, patterns, maxPatternMatchIterations, useTopDownTraversal); 263 } 264 LogicalResult mlir::applyPatternsAndFoldGreedily( 265 MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns, 266 unsigned maxIterations, bool useTopDownTraversal) { 267 if (regions.empty()) 268 return success(); 269 270 // The top-level operation must be known to be isolated from above to 271 // prevent performing canonicalizations on operations defined at or above 272 // the region containing 'op'. 273 auto regionIsIsolated = [](Region ®ion) { 274 return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>(); 275 }; 276 (void)regionIsIsolated; 277 assert(llvm::all_of(regions, regionIsIsolated) && 278 "patterns can only be applied to operations IsolatedFromAbove"); 279 280 // Start the pattern driver. 281 GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, 282 useTopDownTraversal); 283 bool converged = driver.simplify(regions, maxIterations); 284 LLVM_DEBUG(if (!converged) { 285 llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " 286 << maxIterations << " times\n"; 287 }); 288 return success(converged); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // OpPatternRewriteDriver 293 //===----------------------------------------------------------------------===// 294 295 namespace { 296 /// This is a simple driver for the PatternMatcher to apply patterns and perform 297 /// folding on a single op. It repeatedly applies locally optimal patterns. 298 class OpPatternRewriteDriver : public PatternRewriter { 299 public: 300 explicit OpPatternRewriteDriver(MLIRContext *ctx, 301 const FrozenRewritePatternSet &patterns) 302 : PatternRewriter(ctx), matcher(patterns), folder(ctx) { 303 // Apply a simple cost model based solely on pattern benefit. 304 matcher.applyDefaultCostModel(); 305 } 306 307 /// Performs the rewrites and folding only on `op`. The simplification 308 /// converges if the op is erased as a result of being folded, replaced, or 309 /// dead, or no more changes happen in an iteration. Returns success if the 310 /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets 311 /// erased. 312 LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased); 313 314 // These are hooks implemented for PatternRewriter. 315 protected: 316 /// If an operation is about to be removed, mark it so that we can let clients 317 /// know. 318 void notifyOperationRemoved(Operation *op) override { 319 opErasedViaPatternRewrites = true; 320 } 321 322 // When a root is going to be replaced, its removal will be notified as well. 323 // So there is nothing to do here. 324 void notifyRootReplaced(Operation *op) override {} 325 326 private: 327 /// The low-level pattern applicator. 328 PatternApplicator matcher; 329 330 /// Non-pattern based folder for operations. 331 OperationFolder folder; 332 333 /// Set to true if the operation has been erased via pattern rewrites. 334 bool opErasedViaPatternRewrites = false; 335 }; 336 337 } // anonymous namespace 338 339 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op, 340 int maxIterations, 341 bool &erased) { 342 bool changed = false; 343 erased = false; 344 opErasedViaPatternRewrites = false; 345 int i = 0; 346 // Iterate until convergence or until maxIterations. Deletion of the op as 347 // a result of being dead or folded is convergence. 348 do { 349 changed = false; 350 351 // If the operation is trivially dead - remove it. 352 if (isOpTriviallyDead(op)) { 353 op->erase(); 354 erased = true; 355 return success(); 356 } 357 358 // Try to fold this op. 359 bool inPlaceUpdate; 360 if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr, 361 /*preReplaceAction=*/nullptr, 362 &inPlaceUpdate))) { 363 changed = true; 364 if (!inPlaceUpdate) { 365 erased = true; 366 return success(); 367 } 368 } 369 370 // Try to match one of the patterns. The rewriter is automatically 371 // notified of any necessary changes, so there is nothing else to do here. 372 changed |= succeeded(matcher.matchAndRewrite(op, *this)); 373 if ((erased = opErasedViaPatternRewrites)) 374 return success(); 375 } while (changed && ++i < maxIterations); 376 377 // Whether the rewrite converges, i.e. wasn't changed in the last iteration. 378 return failure(changed); 379 } 380 381 /// Rewrites only `op` using the supplied canonicalization patterns and 382 /// folding. `erased` is set to true if the op is erased as a result of being 383 /// folded, replaced, or dead. 384 LogicalResult mlir::applyOpPatternsAndFold( 385 Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) { 386 // Start the pattern driver. 387 OpPatternRewriteDriver driver(op->getContext(), patterns); 388 bool opErased; 389 LogicalResult converged = 390 driver.simplifyLocally(op, maxPatternMatchIterations, opErased); 391 if (erased) 392 *erased = opErased; 393 LLVM_DEBUG(if (failed(converged)) { 394 llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " 395 << maxPatternMatchIterations << " times"; 396 }); 397 return converged; 398 } 399