1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous transformation utilities for the Affine
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/Utils.h"
15 
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/IntegerSet.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "mlir/Transforms/LoopUtils.h"
25 
26 using namespace mlir;
27 
28 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether
29 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
30 /// the rest of the op.
31 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
32   if (elseBlock)
33     assert(ifOp.hasElse() && "else block expected");
34 
35   Block *destBlock = ifOp->getBlock();
36   Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
37   destBlock->getOperations().splice(
38       Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
39       std::prev(srcBlock->end()));
40   ifOp.erase();
41 }
42 
43 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
44 /// on. The `ifOp` could be hoisted and placed right before such an operation.
45 /// This method assumes that the ifOp has been canonicalized (to be correct and
46 /// effective).
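/// For example (an illustrative sketch), in:
///
///   affine.for %i = 0 to 100 {
///     affine.for %j = 0 to 100 {
///       affine.if affine_set<(d0) : (d0 - 10 >= 0)>(%i) { ... }
///     }
///   }
///
/// the condition depends only on %i, so the returned op is the %j loop: the
/// affine.if could be hoisted to just before it, but not past the %i loop.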
47 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
  // Walk up the parent chain past all affine.for/parallel ops that this
  // conditional is invariant on.
49   auto ifOperands = ifOp.getOperands();
50   auto *res = ifOp.getOperation();
51   while (!isa<FuncOp>(res->getParentOp())) {
52     auto *parentOp = res->getParentOp();
53     if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
54       if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
55         break;
    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
      // If the `ifOp` uses any of the parallel op's IVs, it is not invariant
      // on it and can't be hoisted past it.
      if (llvm::any_of(parallelOp.getIVs(), [&](Value iv) {
            return llvm::is_contained(ifOperands, iv);
          }))
        break;
60     } else if (!isa<AffineIfOp>(parentOp)) {
61       // Won't walk up past anything other than affine.for/if ops.
62       break;
63     }
64     // You can always hoist up past any affine.if ops.
65     res = parentOp;
66   }
67   return res;
68 }
69 
70 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
71 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
72 /// otherwise the same `ifOp`.
73 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
74   // No hoisting to do.
75   if (hoistOverOp == ifOp)
76     return ifOp;
77 
78   // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
79   // the else block. Then drop the else block of the original 'if' in the 'then'
80   // branch while promoting its then block, and analogously drop the 'then'
81   // block of the original 'if' from the 'else' branch while promoting its else
82   // block.
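  // An illustrative sketch (op names and bounds are arbitrary):
  //
  //   affine.for %i = 0 to 100 {
  //     affine.if #set(%n) { "foo"() : () -> () }
  //   }
  //
  // roughly becomes
  //
  //   affine.if #set(%n) {
  //     affine.for %i = 0 to 100 { "foo"() : () -> () }
  //   } else {
  //     affine.for %i = 0 to 100 { }
  //   }
  //
  // with the empty else-side loop expected to be cleaned up by later
  // canonicalization.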
83   BlockAndValueMapping operandMap;
84   OpBuilder b(hoistOverOp);
85   auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
86                                           ifOp.getOperands(),
87                                           /*elseBlock=*/true);
88 
89   // Create a clone of hoistOverOp to use for the else branch of the hoisted
90   // conditional. The else block may get optimized away if empty.
91   Operation *hoistOverOpClone = nullptr;
  // We use this unique name to identify/find `ifOp`'s clone in the else
  // version.
94   StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
95   operandMap.clear();
96   b.setInsertionPointAfter(hoistOverOp);
97   // We'll set an attribute to identify this op in a clone of this sub-tree.
98   ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
99   hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
100 
101   // Promote the 'then' block of the original affine.if in the then version.
102   promoteIfBlock(ifOp, /*elseBlock=*/false);
103 
104   // Move the then version to the hoisted if op's 'then' block.
105   auto *thenBlock = hoistedIfOp.getThenBlock();
106   thenBlock->getOperations().splice(thenBlock->begin(),
107                                     hoistOverOp->getBlock()->getOperations(),
108                                     Block::iterator(hoistOverOp));
109 
110   // Find the clone of the original affine.if op in the else version.
111   AffineIfOp ifCloneInElse;
112   hoistOverOpClone->walk([&](AffineIfOp ifClone) {
113     if (!ifClone->getAttr(idForIfOp))
114       return WalkResult::advance();
115     ifCloneInElse = ifClone;
116     return WalkResult::interrupt();
117   });
118   assert(ifCloneInElse && "if op clone should exist");
119   // For the else block, promote the else block of the original 'if' if it had
120   // one; otherwise, the op itself is to be erased.
121   if (!ifCloneInElse.hasElse())
122     ifCloneInElse.erase();
123   else
124     promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
125 
126   // Move the else version into the else block of the hoisted if op.
127   auto *elseBlock = hoistedIfOp.getElseBlock();
128   elseBlock->getOperations().splice(
129       elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
130       Block::iterator(hoistOverOpClone));
131 
132   return hoistedIfOp;
133 }
134 
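// An illustrative sketch of the transformation (dialect prefixes and op names
// are for exposition only):
//
//   %r = affine.for %i = 0 to 10 iter_args(%acc = %init) -> (f32) {
//     %v = ... : f32
//     %sum = addf %acc, %v : f32
//     affine.yield %sum : f32
//   }
//
// roughly becomes
//
//   %red = affine.parallel (%i) = (0) to (10) reduce ("addf") -> (f32) {
//     %v = ... : f32
//     affine.yield %v : f32
//   }
//   %r = addf %init, %red : f32
//
// i.e., the reduction op is moved outside the loop to combine the initial
// value with the parallel reduction result.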
135 LogicalResult
136 mlir::affineParallelize(AffineForOp forOp,
137                         ArrayRef<LoopReduction> parallelReductions) {
138   // Fail early if there are iter arguments that are not reductions.
139   unsigned numReductions = parallelReductions.size();
140   if (numReductions != forOp.getNumIterOperands())
141     return failure();
142 
143   Location loc = forOp.getLoc();
144   OpBuilder outsideBuilder(forOp);
145   AffineMap lowerBoundMap = forOp.getLowerBoundMap();
146   ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
147   AffineMap upperBoundMap = forOp.getUpperBoundMap();
148   ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
149 
  // Create an empty 1-D affine.parallel op.
151   auto reducedValues = llvm::to_vector<4>(llvm::map_range(
152       parallelReductions, [](const LoopReduction &red) { return red.value; }));
153   auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
154       parallelReductions, [](const LoopReduction &red) { return red.kind; }));
155   AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
156       loc, ValueRange(reducedValues).getTypes(), reductionKinds,
157       llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
158       llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
159       llvm::makeArrayRef(forOp.getStep()));
160   // Steal the body of the old affine for op.
161   newPloop.region().takeBody(forOp.region());
162   Operation *yieldOp = &newPloop.getBody()->back();
163 
164   // Handle the initial values of reductions because the parallel loop always
165   // starts from the neutral value.
166   SmallVector<Value> newResults;
167   newResults.reserve(numReductions);
168   for (unsigned i = 0; i < numReductions; ++i) {
169     Value init = forOp.getIterOperands()[i];
170     // This works because we are only handling single-op reductions at the
171     // moment. A switch on reduction kind or a mechanism to collect operations
172     // participating in the reduction will be necessary for multi-op reductions.
173     Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
174     assert(reductionOp && "yielded value is expected to be produced by an op");
175     outsideBuilder.getInsertionBlock()->getOperations().splice(
176         outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
177         reductionOp);
178     reductionOp->setOperands({init, newPloop->getResult(i)});
179     forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
180   }
181 
  // Update the loop terminator to yield reduced values bypassing the reduction
  // operation itself (now moved outside of the loop) and erase the block
  // arguments that correspond to reductions. Note that the loop always has one
  // "main" induction variable when coming from a non-parallel for.
186   unsigned numIVs = 1;
187   yieldOp->setOperands(reducedValues);
188   newPloop.getBody()->eraseArguments(
189       llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));
190 
191   forOp.erase();
192   return success();
193 }
194 
195 // Returns success if any hoisting happened.
196 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
197   // Bail out early if the ifOp returns a result.  TODO: Consider how to
198   // properly support this case.
199   if (ifOp.getNumResults() != 0)
200     return failure();
201 
202   // Apply canonicalization patterns and folding - this is necessary for the
203   // hoisting check to be correct (operands should be composed), and to be more
204   // effective (no unused operands). Since the pattern rewriter's folding is
205   // entangled with application of patterns, we may fold/end up erasing the op,
206   // in which case we return with `folded` being set.
207   RewritePatternSet patterns(ifOp.getContext());
208   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
209   bool erased;
210   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
211   (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
212   if (erased) {
213     if (folded)
214       *folded = true;
215     return failure();
216   }
217   if (folded)
218     *folded = false;
219 
220   // The folding above should have ensured this, but the affine.if's
221   // canonicalization is missing composition of affine.applys into it.
222   assert(llvm::all_of(ifOp.getOperands(),
223                       [](Value v) {
224                         return isTopLevelValue(v) || isForInductionVar(v);
225                       }) &&
226          "operands not composed");
227 
  // We are going to hoist as high as possible.
229   // TODO: this could be customized in the future.
230   auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
231 
232   AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
233   // Nothing to hoist over.
234   if (hoistedIfOp == ifOp)
235     return failure();
236 
237   // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
238   // a sequence of affine.fors that are all perfectly nested).
239   (void)applyPatternsAndFoldGreedily(
240       hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
241       frozenPatterns);
242 
243   return success();
244 }
245 
// Substitutes `dim` in `e` with `min` or `max`: `min` when `positivePath` is
// true at the point of substitution and `max` otherwise. `positivePath` flips
// whenever a negative constant factor/operand is encountered.
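// For example (a worked sketch): with e = d0 - d1 (i.e., d0 + d1 * -1),
// dim = d1, and positivePath = true, the negative multiplier flips the path,
// so d1 is replaced by `max`, yielding d0 - max, which is the smallest value
// of e when d1 ranges over [min, max].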
247 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
248                               AffineExpr max, bool positivePath) {
249   if (e == dim)
250     return positivePath ? min : max;
251   if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
252     AffineExpr lhs = bin.getLHS();
253     AffineExpr rhs = bin.getRHS();
254     if (bin.getKind() == mlir::AffineExprKind::Add)
255       return substWithMin(lhs, dim, min, max, positivePath) +
256              substWithMin(rhs, dim, min, max, positivePath);
257 
258     auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
259     auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
260     if (c1 && c1.getValue() < 0)
261       return getAffineBinaryOpExpr(
262           bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
263     if (c2 && c2.getValue() < 0)
264       return getAffineBinaryOpExpr(
265           bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
266     return getAffineBinaryOpExpr(
267         bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
268         substWithMin(rhs, dim, min, max, positivePath));
269   }
270   return e;
271 }
272 
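// An illustrative sketch of the normalization (constants chosen arbitrarily):
//
//   affine.parallel (%i) = (2) to (20) step (4) { ... %i ... }
//
// roughly becomes
//
//   affine.parallel (%ii) = (0) to (5) {
//     %i = affine.apply affine_map<(d0) -> (d0 * 4 + 2)>(%ii)
//     ... %i ...
//   }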
273 void mlir::normalizeAffineParallel(AffineParallelOp op) {
274   // Loops with min/max in bounds are not normalized at the moment.
275   if (op.hasMinMaxBounds())
276     return;
277 
278   AffineMap lbMap = op.lowerBoundsMap();
279   SmallVector<int64_t, 8> steps = op.getSteps();
280   // No need to do any work if the parallel op is already normalized.
281   bool isAlreadyNormalized =
282       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
283         int64_t step = std::get<0>(tuple);
284         auto lbExpr =
285             std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
286         return lbExpr && lbExpr.getValue() == 0 && step == 1;
287       });
288   if (isAlreadyNormalized)
289     return;
290 
291   AffineValueMap ranges;
292   AffineValueMap::difference(op.getUpperBoundsValueMap(),
293                              op.getLowerBoundsValueMap(), &ranges);
294   auto builder = OpBuilder::atBlockBegin(op.getBody());
295   auto zeroExpr = builder.getAffineConstantExpr(0);
296   SmallVector<AffineExpr, 8> lbExprs;
297   SmallVector<AffineExpr, 8> ubExprs;
298   for (unsigned i = 0, e = steps.size(); i < e; ++i) {
299     int64_t step = steps[i];
300 
301     // Adjust the lower bound to be 0.
302     lbExprs.push_back(zeroExpr);
303 
304     // Adjust the upper bound expression: 'range / step'.
305     AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
306     ubExprs.push_back(ubExpr);
307 
308     // Adjust the corresponding IV: 'lb + i * step'.
309     BlockArgument iv = op.getBody()->getArgument(i);
310     AffineExpr lbExpr = lbMap.getResult(i);
311     unsigned nDims = lbMap.getNumDims();
312     auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
313     auto map = AffineMap::get(/*dimCount=*/nDims + 1,
314                               /*symbolCount=*/lbMap.getNumSymbols(), expr);
315 
316     // Use an 'affine.apply' op that will be simplified later in subsequent
317     // canonicalizations.
318     OperandRange lbOperands = op.getLowerBoundsOperands();
319     OperandRange dimOperands = lbOperands.take_front(nDims);
320     OperandRange symbolOperands = lbOperands.drop_front(nDims);
321     SmallVector<Value, 8> applyOperands{dimOperands};
322     applyOperands.push_back(iv);
323     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
324     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
325     iv.replaceAllUsesExcept(apply, apply);
326   }
327 
328   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
329   op.setSteps(newSteps);
330   auto newLowerMap = AffineMap::get(
331       /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
332   op.setLowerBounds({}, newLowerMap);
333   auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
334                                     ubExprs, op.getContext());
335   op.setUpperBounds(ranges.getOperands(), newUpperMap);
336 }
337 
/// Normalizes affine.for ops. If the affine.for op has only a single
/// iteration, it is simply promoted; otherwise, it is normalized in the
/// traditional way by converting the lower bound to zero and the loop step to
/// one. The upper bound is set to the trip count of the loop. For now,
/// original loops must have a lower bound with a single result only. There is
/// no such restriction on upper bounds.
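/// An illustrative sketch (constants chosen arbitrarily): a loop such as
///
///   affine.for %i = 5 to min affine_map<(d0) -> (d0 + 32, 1024)>(%N) step 4
///
/// roughly becomes
///
///   affine.for %ii = 0 to
///       min affine_map<(d0) -> ((d0 + 32 - 5) ceildiv 4,
///                               (1024 - 5) ceildiv 4)>(%N) {
///     %i = affine.apply affine_map<(d0) -> (d0 * 4 + 5)>(%ii)
///     ...
///   }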
344 void mlir::normalizeAffineFor(AffineForOp op) {
345   if (succeeded(promoteIfSingleIteration(op)))
346     return;
347 
  // Check if the for op is already normalized.
349   if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
350       (op.getStep() == 1))
351     return;
352 
353   // Check if the lower bound has a single result only. Loops with a max lower
354   // bound can't be normalized without additional support like
355   // affine.execute_region's. If the lower bound does not have a single result
356   // then skip this op.
357   if (op.getLowerBoundMap().getNumResults() != 1)
358     return;
359 
360   Location loc = op.getLoc();
361   OpBuilder opBuilder(op);
362   int64_t origLoopStep = op.getStep();
363 
  // Calculate the upper bound for the normalized loop.
365   SmallVector<Value, 4> ubOperands;
366   AffineBound lb = op.getLowerBound();
367   AffineBound ub = op.getUpperBound();
368   ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
369   AffineMap origLbMap = lb.getMap();
370   AffineMap origUbMap = ub.getMap();
371 
372   // Add dimension operands from upper/lower bound.
373   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
374     ubOperands.push_back(ub.getOperand(j));
375   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
376     ubOperands.push_back(lb.getOperand(j));
377 
378   // Add symbol operands from upper/lower bound.
379   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
380     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
381   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
382     ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
383 
384   // Add original result expressions from lower/upper bound map.
385   SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
386                                          origLbMap.getResults().end());
387   SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
388                                          origUbMap.getResults().end());
389   SmallVector<AffineExpr, 4> newUbExprs;
390 
  // The original upper bound can have more than one result. For the new upper
  // bound of this loop, subtract the (single) lower bound result from each
  // upper bound result and ceildiv by the loop step. For example,
  //
  //  affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0)
  //
  //  will get the upper bound map
  //  affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1,
  //                        (1024 - 0) ceildiv 1)>(%i0)
  //
  // Compute a new upper bound expression for each original upper bound result.
401   for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) {
402     newUbExprs.push_back(
403         (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep));
404   }
405 
406   // Construct newUbMap.
407   AffineMap newUbMap =
408       AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
409                      origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
410                      newUbExprs, opBuilder.getContext());
411 
412   // Normalize the loop.
413   op.setUpperBound(ubOperands, newUbMap);
414   op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
415   op.setStep(1);
416 
  // Calculate the value of the new loop IV: create an affine.apply that maps
  // the normalized IV back to the original iteration space.
419   opBuilder.setInsertionPointToStart(op.getBody());
420   SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
421                                    lb.getOperands().begin() +
422                                        lb.getMap().getNumDims());
423   // Add an extra dim operand for loopIV.
424   lbOperands.push_back(op.getInductionVar());
425   // Add symbol operands from lower bound.
426   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
427     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
428 
429   AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
430   AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
431   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
432                                    origLbMap.getNumSymbols(), newIVExpr);
433   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
434   op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
435 }
436 
/// Returns true if no operation that could be executed after `start`
/// (non-inclusive) and prior to `memOp` (e.g. on a control flow/op path
/// between the two operations) has the potential memory effect `EffectType`
/// on `memOp`. `memOp` is an operation that reads or writes to a memref. For
/// example, if `EffectType` is MemoryEffects::Write, this method checks that
/// there is no write to the memory between `start` and `memOp` that would
/// change the value read by `memOp`.
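/// For example (a sketch), with EffectType = MemoryEffects::Write and `memOp`
/// being the load below, the intervening store makes this return false:
///
///   affine.store %v, %A[%i] : memref<100xf32>   // `start`
///   affine.store %w, %A[%i] : memref<100xf32>   // intervening write
///   %x = affine.load %A[%i] : memref<100xf32>   // `memOp`
///
/// whereas a write that dependence analysis proves cannot reach the same
/// element would not.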
444 template <typename EffectType, typename T>
445 static bool hasNoInterveningEffect(Operation *start, T memOp) {
446   Value memref = memOp.getMemRef();
447   bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() ||
448                               memref.getDefiningOp<memref::AllocOp>();
449 
450   // A boolean representing whether an intervening operation could have impacted
451   // memOp.
452   bool hasSideEffect = false;
453 
454   // Check whether the effect on memOp can be caused by a given operation op.
455   std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, exit early.
457     if (hasSideEffect)
458       return;
459 
460     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
461       SmallVector<MemoryEffects::EffectInstance, 1> effects;
462       memEffect.getEffects(effects);
463 
464       bool opMayHaveEffect = false;
465       for (auto effect : effects) {
466         // If op causes EffectType on a potentially aliasing location for
467         // memOp, mark as having the effect.
468         if (isa<EffectType>(effect.getEffect())) {
469           if (isOriginalAllocation && effect.getValue() &&
470               (effect.getValue().getDefiningOp<memref::AllocaOp>() ||
471                effect.getValue().getDefiningOp<memref::AllocOp>())) {
472             if (effect.getValue() != memref)
473               continue;
474           }
475           opMayHaveEffect = true;
476           break;
477         }
478       }
479 
480       if (!opMayHaveEffect)
481         return;
482 
483       // If the side effect comes from an affine read or write, try to
484       // prove the side effecting `op` cannot reach `memOp`.
485       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
486         MemRefAccess srcAccess(op);
487         MemRefAccess destAccess(memOp);
488         // Dependence analysis is only correct if both ops operate on the same
489         // memref.
490         if (srcAccess.memref == destAccess.memref) {
491           FlatAffineValueConstraints dependenceConstraints;
492 
493           // Number of loops containing the start op and the ending operation.
494           unsigned minSurroundingLoops =
495               getNumCommonSurroundingLoops(*start, *memOp);
496 
497           // Number of loops containing the operation `op` which has the
498           // potential memory side effect and can occur on a path between
499           // `start` and `memOp`.
500           unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp);
501 
          // For ease, let's consider the case that `op` is a store and we're
          // looking for other potential stores that overwrite memory after
          // `start` and before it is read in `memOp`. In this case, we only
          // need to consider other potential stores with depth >
          // minSurroundingLoops, since `start` would already overwrite any
          // store with a smaller number of surrounding loops.
508           unsigned d;
509           for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
510             DependenceResult result = checkMemrefAccessDependence(
511                 srcAccess, destAccess, d, &dependenceConstraints,
512                 /*dependenceComponents=*/nullptr);
513             if (hasDependence(result)) {
514               hasSideEffect = true;
515               return;
516             }
517           }
518 
519           // No side effect was seen, simply return.
520           return;
521         }
522       }
523       hasSideEffect = true;
524       return;
525     }
526 
527     if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
528       // Recurse into the regions for this op and check whether the internal
529       // operations may have the side effect `EffectType` on memOp.
530       for (Region &region : op->getRegions())
531         for (Block &block : region)
532           for (Operation &op : block)
533             checkOperation(&op);
534       return;
535     }
536 
    // Otherwise, conservatively assume generic operations have the effect
    // on `memOp`.
539     hasSideEffect = true;
540   };
541 
542   // Check all paths from ancestor op `parent` to the operation `to` for the
543   // effect. It is known that `to` must be contained within `parent`.
544   auto until = [&](Operation *parent, Operation *to) {
    // TODO: check only the paths from `parent` to `to`.
    // Currently we fall back to checking the entire parent op, rather than
    // just the paths from `parent` to `to`, stopping after reaching `to`.
    // This is conservatively correct, but could be made more aggressive.
549     assert(parent->isAncestor(to));
550     checkOperation(parent);
551   };
552 
  // Check all paths from operation `from` to operation `untilOp` for the
  // given memory effect.
555   std::function<void(Operation *, Operation *)> recur =
556       [&](Operation *from, Operation *untilOp) {
557         assert(
558             from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
559             "Checking for side effect between two operations without a common "
560             "ancestor");
561 
        // If the operations are in different regions, recursively consider
        // all paths from `from` to the parent of `untilOp`, and all paths
        // from that parent to `untilOp`.
565         if (from->getParentRegion() != untilOp->getParentRegion()) {
566           recur(from, untilOp->getParentOp());
567           until(untilOp->getParentOp(), untilOp);
568           return;
569         }
570 
        // Now, assuming that `from` and `untilOp` exist in the same region,
        // perform a CFG traversal to check all the relevant operations.
573 
574         // Additional blocks to consider.
575         SmallVector<Block *, 2> todoBlocks;
576         {
          // First consider the parent block of `from` and check all operations
          // after `from`.
579           for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
580                iter != end && &*iter != untilOp; ++iter) {
581             checkOperation(&*iter);
582           }
583 
          // If `from`'s block doesn't contain `untilOp`, add the successors
          // to the list of blocks to check.
586           if (untilOp->getBlock() != from->getBlock())
587             for (Block *succ : from->getBlock()->getSuccessors())
588               todoBlocks.push_back(succ);
589         }
590 
591         SmallPtrSet<Block *, 4> done;
        // Traverse the CFG until hitting `untilOp`.
593         while (!todoBlocks.empty()) {
594           Block *blk = todoBlocks.pop_back_val();
595           if (done.count(blk))
596             continue;
597           done.insert(blk);
598           for (auto &op : *blk) {
599             if (&op == untilOp)
600               break;
601             checkOperation(&op);
602             if (&op == blk->getTerminator())
603               for (Block *succ : blk->getSuccessors())
604                 todoBlocks.push_back(succ);
605           }
606         }
607       };
608   recur(start, memOp);
609   return !hasSideEffect;
610 }
611 
/// Attempt to eliminate loadOp by replacing it with a value stored into memory
/// that the load is guaranteed to retrieve. This check involves three
/// components: 1) the store and the load must access the same location, 2) the
/// store must dominate (and therefore must always execute prior to) the load,
/// and 3) no other operation may overwrite the memory read by the load between
/// the store and the load. If such a store exists, the replaced `loadOp` is
/// added to `loadOpsToErase` and its memref to `memrefsToErase`.
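/// For example (a sketch), in:
///
///   affine.store %v, %A[%i] : memref<100xf32>
///   %x = affine.load %A[%i] : memref<100xf32>
///
/// the load is replaced by %v, provided no other operation that may write to
/// %A[%i] can execute between the two.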
619 static LogicalResult forwardStoreToLoad(
620     AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
621     SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
622 
623   // The store op candidate for forwarding that satisfies all conditions
624   // to replace the load, if any.
625   Operation *lastWriteStoreOp = nullptr;
626 
627   for (auto *user : loadOp.getMemRef().getUsers()) {
628     auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
629     if (!storeOp)
630       continue;
631     MemRefAccess srcAccess(storeOp);
632     MemRefAccess destAccess(loadOp);
633 
634     // 1. Check if the store and the load have mathematically equivalent
635     // affine access functions; this implies that they statically refer to the
636     // same single memref element. As an example this filters out cases like:
637     //     store %A[%i0 + 1]
638     //     load %A[%i0]
639     //     store %A[%M]
640     //     load %A[%N]
641     // Use the AffineValueMap difference based memref access equality checking.
642     if (srcAccess != destAccess)
643       continue;
644 
    // 2. The store has to dominate the load op to be a candidate.
646     if (!domInfo.dominates(storeOp, loadOp))
647       continue;
648 
649     // 3. Ensure there is no intermediate operation which could replace the
650     // value in memory.
651     if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
652       continue;
653 
654     // We now have a candidate for forwarding.
    assert(lastWriteStoreOp == nullptr &&
           "multiple simultaneous replacement stores");
657     lastWriteStoreOp = storeOp;
658   }
659 
660   if (!lastWriteStoreOp)
661     return failure();
662 
663   // Perform the actual store to load forwarding.
664   Value storeVal =
665       cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
  // Check if the two values have the same type. This is needed for affine
  // vector loads and stores.
668   if (storeVal.getType() != loadOp.getValue().getType())
669     return failure();
670   loadOp.getValue().replaceAllUsesWith(storeVal);
671   // Record the memref for a later sweep to optimize away.
672   memrefsToErase.insert(loadOp.getMemRef());
673   // Record this to erase later.
674   loadOpsToErase.push_back(loadOp);
675   return success();
676 }
677 
// This attempts to find stores which have no impact on the final result.
// A writing op writeA will be eliminated if there exists an op writeB such
// that:
// 1) writeA and writeB have mathematically equivalent affine access functions.
// 2) writeB postdominates writeA.
// 3) There is no potential read between writeA and writeB.
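// For example (a sketch), in:
//
//   affine.store %v, %A[%i] : memref<100xf32>   // writeA
//   affine.store %w, %A[%i] : memref<100xf32>   // writeB
//
// writeA is queued in `opsToErase`, since writeB postdominates it and no read
// of %A[%i] can occur in between.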
683 static void findUnusedStore(AffineWriteOpInterface writeA,
684                             SmallVectorImpl<Operation *> &opsToErase,
685                             SmallPtrSetImpl<Value> &memrefsToErase,
686                             PostDominanceInfo &postDominanceInfo) {
687 
688   for (Operation *user : writeA.getMemRef().getUsers()) {
689     // Only consider writing operations.
690     auto writeB = dyn_cast<AffineWriteOpInterface>(user);
691     if (!writeB)
692       continue;
693 
694     // The operations must be distinct.
695     if (writeB == writeA)
696       continue;
697 
698     // Both operations must lie in the same region.
699     if (writeB->getParentRegion() != writeA->getParentRegion())
700       continue;
701 
702     // Both operations must write to the same memory.
703     MemRefAccess srcAccess(writeB);
704     MemRefAccess destAccess(writeA);
705 
706     if (srcAccess != destAccess)
707       continue;
708 
709     // writeB must postdominate writeA.
710     if (!postDominanceInfo.postDominates(writeB, writeA))
711       continue;
712 
713     // There cannot be an operation which reads from memory between
714     // the two writes.
715     if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
716       continue;
717 
718     opsToErase.push_back(writeA);
719     break;
720   }
721 }
722 
723 // The load to load forwarding / redundant load elimination is similar to the
724 // store to load forwarding.
// loadA will be replaced with loadB if:
726 // 1) loadA and loadB have mathematically equivalent affine access functions.
727 // 2) loadB dominates loadA.
728 // 3) There is no write between loadA and loadB.
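// For example (a sketch), in:
//
//   %x = affine.load %A[%i] : memref<100xf32>   // loadB
//   %y = affine.load %A[%i] : memref<100xf32>   // loadA
//
// all uses of %y are replaced with %x and loadA is queued in `loadOpsToErase`,
// provided no write to %A[%i] can occur between the two loads.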
729 static void loadCSE(AffineReadOpInterface loadA,
730                     SmallVectorImpl<Operation *> &loadOpsToErase,
731                     DominanceInfo &domInfo) {
732   SmallVector<AffineReadOpInterface, 4> loadCandidates;
733   for (auto *user : loadA.getMemRef().getUsers()) {
734     auto loadB = dyn_cast<AffineReadOpInterface>(user);
735     if (!loadB || loadB == loadA)
736       continue;
737 
738     MemRefAccess srcAccess(loadB);
739     MemRefAccess destAccess(loadA);
740 
741     // 1. The accesses have to be to the same location.
742     if (srcAccess != destAccess) {
743       continue;
744     }
745 
    // 2. loadB has to dominate loadA to be a candidate for replacement.
747     if (!domInfo.dominates(loadB, loadA))
748       continue;
749 
750     // 3. There is no write between loadA and loadB.
751     if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
752                                                       loadA))
753       continue;
754 
    // Check if the two values have the same type. This is needed for affine
    // vector loads.
757     if (loadB.getValue().getType() != loadA.getValue().getType())
758       continue;
759 
760     loadCandidates.push_back(loadB);
761   }
762 
  // Of the legal load candidates, use the one that dominates all others,
  // to minimize the subsequent need for load CSE.
765   Value loadB;
766   for (AffineReadOpInterface option : loadCandidates) {
767     if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
768           return depStore == option ||
769                  domInfo.dominates(option.getOperation(),
770                                    depStore.getOperation());
771         })) {
772       loadB = option.getValue();
773       break;
774     }
775   }
776 
777   if (loadB) {
778     loadA.getValue().replaceAllUsesWith(loadB);
779     // Record this to erase later.
780     loadOpsToErase.push_back(loadA);
781   }
782 }
783 
784 // The store to load forwarding and load CSE rely on three conditions:
785 //
786 // 1) store/load providing a replacement value and load being replaced need to
787 // have mathematically equivalent affine access functions (checked after full
788 // composition of load/store operands); this implies that they access the same
789 // single memref element for all iterations of the common surrounding loop,
790 //
791 // 2) the store/load op should dominate the load op,
792 //
793 // 3) no operation that may write to memory read by the load being replaced can
794 // occur after executing the instruction (load or store) providing the
795 // replacement value and before the load being replaced (thus potentially
796 // allowing overwriting the memory read by the load).
797 //
// The above conditions are simple to check, sufficient, and powerful for most
// cases in practice. They are sufficient but not necessary, since they don't
// reason about loops that are guaranteed to execute at least once, or about
// multiple sources to forward from.
//
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memrefs. This pass
// currently eliminates stores only if no other loads/uses (other
// than dealloc) remain.
808 //
809 void mlir::affineScalarReplace(FuncOp f, DominanceInfo &domInfo,
810                                PostDominanceInfo &postDomInfo) {
  // Ops to erase: load ops whose results were replaced by forwarded values
  // and, later, dead store ops.
812   SmallVector<Operation *, 8> opsToErase;
813 
  // A list of memrefs that are potentially dead / could be eliminated.
815   SmallPtrSet<Value, 4> memrefsToErase;
816 
  // Walk all loads and perform store to load forwarding.
818   f.walk([&](AffineReadOpInterface loadOp) {
819     if (failed(
820             forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
821       loadCSE(loadOp, opsToErase, domInfo);
822     }
823   });
824 
  // Erase all load ops whose results were replaced with store-forwarded ones.
826   for (auto *op : opsToErase)
827     op->erase();
828   opsToErase.clear();
829 
  // Walk all stores and perform unused store elimination.
831   f.walk([&](AffineWriteOpInterface storeOp) {
832     findUnusedStore(storeOp, opsToErase, memrefsToErase, postDomInfo);
833   });
  // Erase all store ops that don't impact the program.
835   for (auto *op : opsToErase)
836     op->erase();
837 
838   // Check if the store fwd'ed memrefs are now left with only stores and can
839   // thus be completely deleted. Note: the canonicalize pass should be able
840   // to do this as well, but we'll do it here since we collected these anyway.
841   for (auto memref : memrefsToErase) {
842     // If the memref hasn't been alloc'ed in this function, skip.
843     Operation *defOp = memref.getDefiningOp();
844     if (!defOp || !isa<memref::AllocOp>(defOp))
845       // TODO: if the memref was returned by a 'call' operation, we
846       // could still erase it if the call had no side-effects.
847       continue;
848     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
849           return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
850         }))
851       continue;
852 
853     // Erase all stores, the dealloc, and the alloc on the memref.
854     for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
855       user->erase();
856     defOp->erase();
857   }
858 }
859