//===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous transformation utilities for the Affine
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"

using namespace mlir;

/// Promotes the `then` or the `else` block of `ifOp` (depending on whether
/// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
/// the rest of the op.
static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
  if (elseBlock)
    assert(ifOp.hasElse() && "else block expected");

  Block *destBlock = ifOp->getBlock();
  Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
  destBlock->getOperations().splice(
      Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
      std::prev(srcBlock->end()));
  ifOp.erase();
}

/// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
/// on. The `ifOp` could be hoisted and placed right before such an operation.
/// This method assumes that the `ifOp` has been canonicalized (to be correct
/// and effective).
static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
  // Walk up the parents past all for/parallel ops that this conditional is
  // invariant on.
  auto ifOperands = ifOp.getOperands();
  auto *res = ifOp.getOperation();
  while (!isa<FuncOp>(res->getParentOp())) {
    auto *parentOp = res->getParentOp();
    if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
      if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
        break;
    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
      if (llvm::any_of(parallelOp.getIVs(), [&](Value iv) {
            return llvm::is_contained(ifOperands, iv);
          }))
        break;
    } else if (!isa<AffineIfOp>(parentOp)) {
      // Won't walk up past anything other than affine.for/parallel/if ops.
      break;
    }
    // You can always hoist up past any affine.if ops.
    res = parentOp;
  }
  return res;
}

/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
/// otherwise the same `ifOp`.
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
  // No hoisting to do.
  if (hoistOverOp == ifOp)
    return ifOp;

  // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
  // the else block. Then drop the else block of the original 'if' in the
  // 'then' branch while promoting its then block, and analogously drop the
  // 'then' block of the original 'if' from the 'else' branch while promoting
  // its else block.
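  // Illustrative sketch (hypothetical IR, for exposition only): hoisting an
  // affine.if over a single loop it is invariant on,
  //
  //   affine.for %i = 0 to %N {
  //     affine.if #cond(%M) { "work"() : () -> () }
  //   }
  //
  // roughly yields
  //
  //   affine.if #cond(%M) {
  //     affine.for %i = 0 to %N { "work"() : () -> () }
  //   } else {
  //     affine.for %i = 0 to %N { }
  //   }
  //
  // with the empty else copy expected to be removed by later canonicalization.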
  BlockAndValueMapping operandMap;
  OpBuilder b(hoistOverOp);
  auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
                                          ifOp.getOperands(),
                                          /*elseBlock=*/true);

  // Create a clone of hoistOverOp to use for the else branch of the hoisted
  // conditional. The else block may get optimized away if empty.
  Operation *hoistOverOpClone = nullptr;
  // We use this unique name to identify/find `ifOp`'s clone in the else
  // version.
  StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
  operandMap.clear();
  b.setInsertionPointAfter(hoistOverOp);
  // We'll set an attribute to identify this op in a clone of this sub-tree.
  ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);

  // Promote the 'then' block of the original affine.if in the then version.
  promoteIfBlock(ifOp, /*elseBlock=*/false);

  // Move the then version to the hoisted if op's 'then' block.
  auto *thenBlock = hoistedIfOp.getThenBlock();
  thenBlock->getOperations().splice(thenBlock->begin(),
                                    hoistOverOp->getBlock()->getOperations(),
                                    Block::iterator(hoistOverOp));

  // Find the clone of the original affine.if op in the else version.
  AffineIfOp ifCloneInElse;
  hoistOverOpClone->walk([&](AffineIfOp ifClone) {
    if (!ifClone->getAttr(idForIfOp))
      return WalkResult::advance();
    ifCloneInElse = ifClone;
    return WalkResult::interrupt();
  });
  assert(ifCloneInElse && "if op clone should exist");
  // For the else block, promote the else block of the original 'if' if it had
  // one; otherwise, the op itself is to be erased.
  if (!ifCloneInElse.hasElse())
    ifCloneInElse.erase();
  else
    promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);

  // Move the else version into the else block of the hoisted if op.
  auto *elseBlock = hoistedIfOp.getElseBlock();
  elseBlock->getOperations().splice(
      elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
      Block::iterator(hoistOverOpClone));

  return hoistedIfOp;
}

LogicalResult
mlir::affineParallelize(AffineForOp forOp,
                        ArrayRef<LoopReduction> parallelReductions) {
  // Fail early if there are iter arguments that are not reductions.
  unsigned numReductions = parallelReductions.size();
  if (numReductions != forOp.getNumIterOperands())
    return failure();

  Location loc = forOp.getLoc();
  OpBuilder outsideBuilder(forOp);
  AffineMap lowerBoundMap = forOp.getLowerBoundMap();
  ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
  AffineMap upperBoundMap = forOp.getUpperBoundMap();
  ValueRange upperBoundOperands = forOp.getUpperBoundOperands();

  // Creating empty 1-D affine.parallel op.
  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.value; }));
  auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.kind; }));
  AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
      loc, ValueRange(reducedValues).getTypes(), reductionKinds,
      llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
      llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
      llvm::makeArrayRef(forOp.getStep()));
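  // Illustrative sketch (exact op spellings may differ): a reduction loop such
  // as
  //
  //   %r = affine.for %i = 0 to 128 iter_args(%acc = %init) -> (f32) {
  //     %v = affine.load %A[%i] : memref<128xf32>
  //     %s = arith.addf %acc, %v : f32
  //     affine.yield %s : f32
  //   }
  //
  // is rewritten into
  //
  //   %p = affine.parallel (%i) = (0) to (128) reduce ("addf") -> (f32) {
  //     %v = affine.load %A[%i] : memref<128xf32>
  //     affine.yield %v : f32
  //   }
  //   %r = arith.addf %init, %p : f32
  //
  // with the combining op moved right after the loop; this is what the rest of
  // this function sets up.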
  // Steal the body of the old affine for op.
  newPloop.region().takeBody(forOp.region());
  Operation *yieldOp = &newPloop.getBody()->back();

  // Handle the initial values of reductions because the parallel loop always
  // starts from the neutral value.
  SmallVector<Value> newResults;
  newResults.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    Value init = forOp.getIterOperands()[i];
    // This works because we are only handling single-op reductions at the
    // moment. A switch on reduction kind or a mechanism to collect operations
    // participating in the reduction will be necessary for multi-op
    // reductions.
    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
    assert(reductionOp && "yielded value is expected to be produced by an op");
    outsideBuilder.getInsertionBlock()->getOperations().splice(
        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
        reductionOp);
    reductionOp->setOperands({init, newPloop->getResult(i)});
    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
  }

  // Update the loop terminator to yield reduced values bypassing the reduction
  // operation itself (now moved outside of the loop) and erase the block
  // arguments that correspond to reductions. Note that the loop always has one
  // "main" induction variable when coming from a non-parallel for.
  unsigned numIVs = 1;
  yieldOp->setOperands(reducedValues);
  newPloop.getBody()->eraseArguments(
      llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));

  forOp.erase();
  return success();
}

// Returns success if any hoisting happened.
LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
  // Bail out early if the ifOp returns a result. TODO: Consider how to
  // properly support this case.
  if (ifOp.getNumResults() != 0)
    return failure();

  // Apply canonicalization patterns and folding - this is necessary for the
  // hoisting check to be correct (operands should be composed), and to be more
  // effective (no unused operands). Since the pattern rewriter's folding is
  // entangled with application of patterns, we may fold/end up erasing the op,
  // in which case we return with `folded` being set.
  RewritePatternSet patterns(ifOp.getContext());
  AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
  bool erased;
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
  if (erased) {
    if (folded)
      *folded = true;
    return failure();
  }
  if (folded)
    *folded = false;

  // The folding above should have ensured this, but the affine.if's
  // canonicalization is missing composition of affine.applys into it.
  assert(llvm::all_of(ifOp.getOperands(),
                      [](Value v) {
                        return isTopLevelValue(v) || isForInductionVar(v);
                      }) &&
         "operands not composed");

  // We are going to hoist as high as possible.
  // TODO: this could be customized in the future.
  auto *hoistOverOp = getOutermostInvariantForOp(ifOp);

  AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
  // Nothing to hoist over.
  if (hoistedIfOp == ifOp)
    return failure();

  // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
  // a sequence of affine.fors that are all perfectly nested).
  (void)applyPatternsAndFoldGreedily(
      hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
      frozenPatterns);

  return success();
}

// Returns the result of replacing `dim` in `e` by `min` along positive paths
// and by `max` along negative paths. A path flips sign each time the
// subexpression is combined with a negative constant (e.g., multiplied by it).
AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
                              AffineExpr max, bool positivePath) {
  if (e == dim)
    return positivePath ? min : max;
  if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
    AffineExpr lhs = bin.getLHS();
    AffineExpr rhs = bin.getRHS();
    if (bin.getKind() == mlir::AffineExprKind::Add)
      return substWithMin(lhs, dim, min, max, positivePath) +
             substWithMin(rhs, dim, min, max, positivePath);

    auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
    auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
    if (c1 && c1.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
    if (c2 && c2.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
    return getAffineBinaryOpExpr(
        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
        substWithMin(rhs, dim, min, max, positivePath));
  }
  return e;
}

void mlir::normalizeAffineParallel(AffineParallelOp op) {
  // Loops with min/max in bounds are not normalized at the moment.
  if (op.hasMinMaxBounds())
    return;

  AffineMap lbMap = op.lowerBoundsMap();
  SmallVector<int64_t, 8> steps = op.getSteps();
  // No need to do any work if the parallel op is already normalized.
  bool isAlreadyNormalized =
      llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
        int64_t step = std::get<0>(tuple);
        auto lbExpr =
            std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
        return lbExpr && lbExpr.getValue() == 0 && step == 1;
      });
  if (isAlreadyNormalized)
    return;

  AffineValueMap ranges;
  AffineValueMap::difference(op.getUpperBoundsValueMap(),
                             op.getLowerBoundsValueMap(), &ranges);
  auto builder = OpBuilder::atBlockBegin(op.getBody());
  auto zeroExpr = builder.getAffineConstantExpr(0);
  SmallVector<AffineExpr, 8> lbExprs;
  SmallVector<AffineExpr, 8> ubExprs;
  for (unsigned i = 0, e = steps.size(); i < e; ++i) {
    int64_t step = steps[i];

    // Adjust the lower bound to be 0.
    lbExprs.push_back(zeroExpr);

    // Adjust the upper bound expression: 'range ceildiv step'.
    AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
    ubExprs.push_back(ubExpr);

    // Adjust the corresponding IV: 'lb + i * step'.
    BlockArgument iv = op.getBody()->getArgument(i);
    AffineExpr lbExpr = lbMap.getResult(i);
    unsigned nDims = lbMap.getNumDims();
    auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
    auto map = AffineMap::get(/*dimCount=*/nDims + 1,
                              /*symbolCount=*/lbMap.getNumSymbols(), expr);

    // Use an 'affine.apply' op that will be simplified later in subsequent
    // canonicalizations.
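    // For example (illustrative): with a lower bound expression `d0 * 2` and
    // step 3, `map` is (d0, d1) -> (d0 * 2 + d1 * 3), applied to the lower
    // bound's dim operands, followed by the IV, followed by the lower bound's
    // symbol operands.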
    OperandRange lbOperands = op.getLowerBoundsOperands();
    OperandRange dimOperands = lbOperands.take_front(nDims);
    OperandRange symbolOperands = lbOperands.drop_front(nDims);
    SmallVector<Value, 8> applyOperands{dimOperands};
    applyOperands.push_back(iv);
    applyOperands.append(symbolOperands.begin(), symbolOperands.end());
    auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
    iv.replaceAllUsesExcept(apply, apply);
  }

  SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
  op.setSteps(newSteps);
  auto newLowerMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
  op.setLowerBounds({}, newLowerMap);
  auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
                                    ubExprs, op.getContext());
  op.setUpperBounds(ranges.getOperands(), newUpperMap);
}

/// Normalizes affine.for ops. If the affine.for op has only a single
/// iteration, it is simply promoted. Otherwise, it is normalized in the
/// traditional way, by converting the lower bound to zero and the loop step to
/// one; the upper bound is set to the trip count of the loop. For now, the
/// original loop's lower bound map must have a single result; there is no such
/// restriction on the upper bound.
void mlir::normalizeAffineFor(AffineForOp op) {
  if (succeeded(promoteIfSingleIteration(op)))
    return;

  // Check if the for op is already normalized.
  if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
      (op.getStep() == 1))
    return;

  // Check if the lower bound has a single result only. Loops with a max lower
  // bound can't be normalized without additional support like
  // affine.execute_region's. If the lower bound does not have a single result
  // then skip this op.
  if (op.getLowerBoundMap().getNumResults() != 1)
    return;

  Location loc = op.getLoc();
  OpBuilder opBuilder(op);
  int64_t origLoopStep = op.getStep();

  // Calculate the upper bound for the normalized loop.
  SmallVector<Value, 4> ubOperands;
  AffineBound lb = op.getLowerBound();
  AffineBound ub = op.getUpperBound();
  ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
  AffineMap origLbMap = lb.getMap();
  AffineMap origUbMap = ub.getMap();

  // Add dimension operands from upper/lower bound.
  for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
    ubOperands.push_back(ub.getOperand(j));
  for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
    ubOperands.push_back(lb.getOperand(j));

  // Add symbol operands from upper/lower bound.
  for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
    ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
  for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
    ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));

  // Add original result expressions from lower/upper bound map.
  SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
                                         origLbMap.getResults().end());
  SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
                                         origUbMap.getResults().end());
  SmallVector<AffineExpr, 4> newUbExprs;

  // The original upper bound can have more than one result. For the new upper
  // bound of this loop, take the difference of each upper bound result and the
  // lower bound result, and ceildiv by the loop step. For example,
  //
  //   affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0)
  //
  // will have an upper bound map of
  //
  //   affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1,
  //                         (1024 - 0) ceildiv 1)>(%i0)
  //
  // Insert all combinations of upper/lower bound results.
  for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) {
    newUbExprs.push_back(
        (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep));
  }

  // Construct newUbMap.
  AffineMap newUbMap =
      AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
                     origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
                     newUbExprs, opBuilder.getContext());

  // Normalize the loop.
  op.setUpperBound(ubOperands, newUbMap);
  op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
  op.setStep(1);

  // Calculate the value the original loop IV would have taken and materialize
  // it with an 'affine.apply' at the start of the normalized loop's body; all
  // other uses of the IV are redirected to this value.
  opBuilder.setInsertionPointToStart(op.getBody());
  SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
                                   lb.getOperands().begin() +
                                       lb.getMap().getNumDims());
  // Add an extra dim operand for the loop IV.
  lbOperands.push_back(op.getInductionVar());
  // Add symbol operands from the lower bound.
  for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
    lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));

  AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
  AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
  AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
                                   origLbMap.getNumSymbols(), newIVExpr);
  Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
}

/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
/// between the operations) do not have the potential memory effect
/// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to
/// a memref. For example, if `EffectType` is MemoryEffects::Write, this method
/// will check if there is no write to the memory between `start` and `memOp`
/// that would change the read within `memOp`.
template <typename EffectType, typename T>
static bool hasNoInterveningEffect(Operation *start, T memOp) {
  Value memref = memOp.getMemRef();
  bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() ||
                              memref.getDefiningOp<memref::AllocOp>();

  // A boolean representing whether an intervening operation could have
  // impacted memOp.
  bool hasSideEffect = false;

  // Check whether the effect on memOp can be caused by a given operation op.
  std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, return early.
    if (hasSideEffect)
      return;

    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);

      bool opMayHaveEffect = false;
      for (auto effect : effects) {
        // If op causes EffectType on a potentially aliasing location for
        // memOp, mark as having the effect.
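        // (Two distinct locally allocated buffers cannot alias, so an effect
        // confined to a different alloc/alloca than `memref` is skipped by the
        // check below.)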
        if (isa<EffectType>(effect.getEffect())) {
          if (isOriginalAllocation && effect.getValue() &&
              (effect.getValue().getDefiningOp<memref::AllocaOp>() ||
               effect.getValue().getDefiningOp<memref::AllocOp>())) {
            if (effect.getValue() != memref)
              continue;
          }
          opMayHaveEffect = true;
          break;
        }
      }

      if (!opMayHaveEffect)
        return;

      // If the side effect comes from an affine read or write, try to
      // prove the side effecting `op` cannot reach `memOp`.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        MemRefAccess srcAccess(op);
        MemRefAccess destAccess(memOp);
        // Dependence analysis is only correct if both ops operate on the same
        // memref.
        if (srcAccess.memref == destAccess.memref) {
          FlatAffineValueConstraints dependenceConstraints;

          // Number of loops containing the start op and the ending operation.
          unsigned minSurroundingLoops =
              getNumCommonSurroundingLoops(*start, *memOp);

          // Number of loops containing the operation `op` which has the
          // potential memory side effect and can occur on a path between
          // `start` and `memOp`.
          unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp);

          // For ease, let's consider the case where `op` is a store and we're
          // looking for other potential stores (e.g., `op`) that overwrite
          // memory after `start` and before it is read in `memOp`. In this
          // case, we only need to consider other potential stores with depth >
          // minSurroundingLoops, since `start` would overwrite any store with
          // a smaller number of surrounding loops before.
          unsigned d;
          for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
            DependenceResult result = checkMemrefAccessDependence(
                srcAccess, destAccess, d, &dependenceConstraints,
                /*dependenceComponents=*/nullptr);
            if (hasDependence(result)) {
              hasSideEffect = true;
              return;
            }
          }

          // No side effect was seen, simply return.
          return;
        }
      }
      hasSideEffect = true;
      return;
    }

    if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
      // Recurse into the regions for this op and check whether the internal
      // operations may have the side effect `EffectType` on memOp.
      for (Region &region : op->getRegions())
        for (Block &block : region)
          for (Operation &op : block)
            checkOperation(&op);
      return;
    }

    // Otherwise, conservatively assume generic operations have the effect
    // on `memOp`.
    hasSideEffect = true;
  };

  // Check all paths from ancestor op `parent` to the operation `to` for the
  // effect. It is known that `to` must be contained within `parent`.
  auto until = [&](Operation *parent, Operation *to) {
    // TODO: check only the paths from `parent` to `to`.
    // Currently we fall back and check the entire parent op, rather than
    // just the paths from `parent` to `to`, stopping after reaching `to`.
    // This is conservatively correct, but could be made more aggressive.
    assert(parent->isAncestor(to));
    checkOperation(parent);
  };

  // Check for all paths from operation `from` to operation `untilOp` for the
  // given memory effect.
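  // The traversal below first scans the remainder of `from`'s block, then
  // follows CFG successors until `untilOp` is reached; when `from` and
  // `untilOp` are in different regions, it recurses via `untilOp`'s parent op.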
  std::function<void(Operation *, Operation *)> recur =
      [&](Operation *from, Operation *untilOp) {
        assert(
            from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
            "Checking for side effect between two operations without a common "
            "ancestor");

        // If the operations are in different regions, recursively consider all
        // paths from `from` to the parent of `untilOp` and all paths from the
        // parent of `untilOp` to `untilOp`.
        if (from->getParentRegion() != untilOp->getParentRegion()) {
          recur(from, untilOp->getParentOp());
          until(untilOp->getParentOp(), untilOp);
          return;
        }

        // Now, assuming that `from` and `untilOp` exist in the same region,
        // perform a CFG traversal to check all the relevant operations.

        // Additional blocks to consider.
        SmallVector<Block *, 2> todoBlocks;
        {
          // First consider the parent block of `from` and check all operations
          // after `from`.
          for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
               iter != end && &*iter != untilOp; ++iter) {
            checkOperation(&*iter);
          }

          // If the parent block of `from` doesn't contain `untilOp`, add the
          // successors to the list of blocks to check.
          if (untilOp->getBlock() != from->getBlock())
            for (Block *succ : from->getBlock()->getSuccessors())
              todoBlocks.push_back(succ);
        }

        SmallPtrSet<Block *, 4> done;
        // Traverse the CFG until hitting `untilOp`.
        while (!todoBlocks.empty()) {
          Block *blk = todoBlocks.pop_back_val();
          if (done.count(blk))
            continue;
          done.insert(blk);
          for (auto &op : *blk) {
            if (&op == untilOp)
              break;
            checkOperation(&op);
            if (&op == blk->getTerminator())
              for (Block *succ : blk->getSuccessors())
                todoBlocks.push_back(succ);
          }
        }
      };
  recur(start, memOp);
  return !hasSideEffect;
}

/// Attempt to eliminate loadOp by replacing it with a value stored into memory
/// which the load is guaranteed to retrieve. This check involves three
/// components: 1) the store and the load must be to the same location, 2) the
/// store must dominate (and therefore must always occur prior to) the load,
/// and 3) no other operation may overwrite the memory between the given store
/// and the load. If such a value exists, the replaced `loadOp` will be added
/// to `loadOpsToErase` and its memref will be added to `memrefsToErase`.
static LogicalResult forwardStoreToLoad(
    AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
    SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {

  // The store op candidate for forwarding that satisfies all conditions
  // to replace the load, if any.
  Operation *lastWriteStoreOp = nullptr;

  for (auto *user : loadOp.getMemRef().getUsers()) {
    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
    if (!storeOp)
      continue;
    MemRefAccess srcAccess(storeOp);
    MemRefAccess destAccess(loadOp);

    // 1. Check if the store and the load have mathematically equivalent
    // affine access functions; this implies that they statically refer to the
    // same single memref element. As an example this filters out cases like:
    //   store %A[%i0 + 1]
    //   load %A[%i0]
    //   store %A[%M]
    //   load %A[%N]
    // Use the AffineValueMap difference based memref access equality checking.
    if (srcAccess != destAccess)
      continue;

    // 2. The store has to dominate the load op to be a candidate.
    if (!domInfo.dominates(storeOp, loadOp))
      continue;

    // 3. Ensure there is no intermediate operation which could replace the
    // value in memory.
    if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
      continue;

    // We now have a candidate for forwarding.
    assert(lastWriteStoreOp == nullptr &&
           "multiple simultaneous replacement stores");
    lastWriteStoreOp = storeOp;
  }

  if (!lastWriteStoreOp)
    return failure();

  // Perform the actual store to load forwarding.
  Value storeVal =
      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
  // Check if the two values have the same type. This is needed for affine
  // vector loads and stores.
  if (storeVal.getType() != loadOp.getValue().getType())
    return failure();
  loadOp.getValue().replaceAllUsesWith(storeVal);
  // Record the memref for a later sweep to optimize away.
  memrefsToErase.insert(loadOp.getMemRef());
  // Record this to erase later.
  loadOpsToErase.push_back(loadOp);
  return success();
}

// This attempts to find stores which have no impact on the final result.
// A writing op writeA will be eliminated if there exists an op writeB such
// that:
// 1) writeA and writeB have mathematically equivalent affine access functions.
// 2) writeB postdominates writeA.
// 3) There is no potential read between writeA and writeB.
static void findUnusedStore(AffineWriteOpInterface writeA,
                            SmallVectorImpl<Operation *> &opsToErase,
                            SmallPtrSetImpl<Value> &memrefsToErase,
                            PostDominanceInfo &postDominanceInfo) {

  for (Operation *user : writeA.getMemRef().getUsers()) {
    // Only consider writing operations.
    auto writeB = dyn_cast<AffineWriteOpInterface>(user);
    if (!writeB)
      continue;

    // The operations must be distinct.
    if (writeB == writeA)
      continue;

    // Both operations must lie in the same region.
    if (writeB->getParentRegion() != writeA->getParentRegion())
      continue;

    // Both operations must write to the same memory.
    MemRefAccess srcAccess(writeB);
    MemRefAccess destAccess(writeA);

    if (srcAccess != destAccess)
      continue;

    // writeB must postdominate writeA.
    if (!postDominanceInfo.postDominates(writeB, writeA))
      continue;

    // There cannot be an operation which reads from memory between
    // the two writes.
    if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
      continue;

    opsToErase.push_back(writeA);
    break;
  }
}

// The load to load forwarding / redundant load elimination is similar to the
// store to load forwarding.
// loadA will be replaced with loadB if:
// 1) loadA and loadB have mathematically equivalent affine access functions.
// 2) loadB dominates loadA.
// 3) There is no write between loadA and loadB.
static void loadCSE(AffineReadOpInterface loadA,
                    SmallVectorImpl<Operation *> &loadOpsToErase,
                    DominanceInfo &domInfo) {
  SmallVector<AffineReadOpInterface, 4> loadCandidates;
  for (auto *user : loadA.getMemRef().getUsers()) {
    auto loadB = dyn_cast<AffineReadOpInterface>(user);
    if (!loadB || loadB == loadA)
      continue;

    MemRefAccess srcAccess(loadB);
    MemRefAccess destAccess(loadA);

    // 1. The accesses have to be to the same location.
    if (srcAccess != destAccess) {
      continue;
    }

    // 2. loadB has to dominate loadA to be a candidate.
    if (!domInfo.dominates(loadB, loadA))
      continue;

    // 3. There is no write between loadA and loadB.
    if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
                                                      loadA))
      continue;

    // Check if the two values have the same type. This is needed for affine
    // vector loads.
    if (loadB.getValue().getType() != loadA.getValue().getType())
      continue;

    loadCandidates.push_back(loadB);
  }

  // Of the legal load candidates, use the one that dominates all others
  // to minimize the subsequent need for load CSE.
  Value loadB;
  for (AffineReadOpInterface option : loadCandidates) {
    if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
          return depStore == option ||
                 domInfo.dominates(option.getOperation(),
                                   depStore.getOperation());
        })) {
      loadB = option.getValue();
      break;
    }
  }

  if (loadB) {
    loadA.getValue().replaceAllUsesWith(loadB);
    // Record this to erase later.
    loadOpsToErase.push_back(loadA);
  }
}

// The store to load forwarding and load CSE rely on three conditions:
//
// 1) the store/load providing a replacement value and the load being replaced
// need to have mathematically equivalent affine access functions (checked
// after full composition of load/store operands); this implies that they
// access the same single memref element for all iterations of the common
// surrounding loop,
//
// 2) the store/load op should dominate the load op,
//
// 3) no operation that may write to memory read by the load being replaced can
// occur after executing the instruction (load or store) providing the
// replacement value and before the load being replaced (thus potentially
// allowing overwriting the memory read by the load).
//
// The above conditions are simple to check and powerful for most cases in
// practice. They are sufficient but not necessary, since they don't reason
// about loops that are guaranteed to execute at least once, nor about multiple
// sources to forward from.
//
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memref's. This pass
// currently eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
void mlir::affineScalarReplace(FuncOp f, DominanceInfo &domInfo,
                               PostDominanceInfo &postDomInfo) {
  // Load op's whose results were replaced by those forwarded from stores.
  SmallVector<Operation *, 8> opsToErase;

  // A list of memref's that are potentially dead / could be eliminated.
  SmallPtrSet<Value, 4> memrefsToErase;

  // Walk all load's and perform store to load forwarding.
  f.walk([&](AffineReadOpInterface loadOp) {
    if (failed(
            forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
      loadCSE(loadOp, opsToErase, domInfo);
    }
  });

  // Erase all load op's whose results were replaced with store fwd'ed ones.
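  // (Erasure is deferred to after the walk so the traversal above is not
  // invalidated by deleting the ops it visits.)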
  for (auto *op : opsToErase)
    op->erase();
  opsToErase.clear();

  // Walk all store's and perform unused store elimination.
  f.walk([&](AffineWriteOpInterface storeOp) {
    findUnusedStore(storeOp, opsToErase, memrefsToErase, postDomInfo);
  });
  // Erase all store op's which don't impact the program.
  for (auto *op : opsToErase)
    op->erase();

  // Check if the store fwd'ed memrefs are now left with only stores and can
  // thus be completely deleted. Note: the canonicalize pass should be able
  // to do this as well, but we'll do it here since we collected these anyway.
  for (auto memref : memrefsToErase) {
    // If the memref hasn't been alloc'ed in this function, skip.
    Operation *defOp = memref.getDefiningOp();
    if (!defOp || !isa<memref::AllocOp>(defOp))
      // TODO: if the memref was returned by a 'call' operation, we
      // could still erase it if the call had no side-effects.
      continue;
    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
          return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
        }))
      continue;

    // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
      user->erase();
    defOp->erase();
  }
}