1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous transformation utilities for the Affine
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/Utils.h"
15 
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 #define DEBUG_TYPE "affine-utils"
28 
29 using namespace mlir;
30 
31 namespace {
32 /// Visit affine expressions recursively and build the sequence of operations
33 /// that correspond to it.  Visitation functions return an Value of the
34 /// expression subtree they visited or `nullptr` on error.
35 class AffineApplyExpander
36     : public AffineExprVisitor<AffineApplyExpander, Value> {
37 public:
38   /// This internal class expects arguments to be non-null, checks must be
39   /// performed at the call site.
40   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
41                       ValueRange symbolValues, Location loc)
42       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
43         loc(loc) {}
44 
45   template <typename OpTy>
46   Value buildBinaryExpr(AffineBinaryOpExpr expr) {
47     auto lhs = visit(expr.getLHS());
48     auto rhs = visit(expr.getRHS());
49     if (!lhs || !rhs)
50       return nullptr;
51     auto op = builder.create<OpTy>(loc, lhs, rhs);
52     return op.getResult();
53   }
54 
55   Value visitAddExpr(AffineBinaryOpExpr expr) {
56     return buildBinaryExpr<arith::AddIOp>(expr);
57   }
58 
59   Value visitMulExpr(AffineBinaryOpExpr expr) {
60     return buildBinaryExpr<arith::MulIOp>(expr);
61   }
62 
63   /// Euclidean modulo operation: negative RHS is not allowed.
64   /// Remainder of the euclidean integer division is always non-negative.
65   ///
66   /// Implemented as
67   ///
68   ///     a mod b =
69   ///         let remainder = srem a, b;
70   ///             negative = a < 0 in
71   ///         select negative, remainder + b, remainder.
72   Value visitModExpr(AffineBinaryOpExpr expr) {
73     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
74     if (!rhsConst) {
75       emitError(
76           loc,
77           "semi-affine expressions (modulo by non-const) are not supported");
78       return nullptr;
79     }
80     if (rhsConst.getValue() <= 0) {
81       emitError(loc, "modulo by non-positive value is not supported");
82       return nullptr;
83     }
84 
85     auto lhs = visit(expr.getLHS());
86     auto rhs = visit(expr.getRHS());
87     assert(lhs && rhs && "unexpected affine expr lowering failure");
88 
89     Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
90     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
91     Value isRemainderNegative = builder.create<arith::CmpIOp>(
92         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
93     Value correctedRemainder =
94         builder.create<arith::AddIOp>(loc, remainder, rhs);
95     Value result = builder.create<arith::SelectOp>(
96         loc, isRemainderNegative, correctedRemainder, remainder);
97     return result;
98   }
99 
100   /// Floor division operation (rounds towards negative infinity).
101   ///
102   /// For positive divisors, it can be implemented without branching and with a
103   /// single division operation as
104   ///
105   ///        a floordiv b =
106   ///            let negative = a < 0 in
107   ///            let absolute = negative ? -a - 1 : a in
108   ///            let quotient = absolute / b in
109   ///                negative ? -quotient - 1 : quotient
110   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
111     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
112     if (!rhsConst) {
113       emitError(
114           loc,
115           "semi-affine expressions (division by non-const) are not supported");
116       return nullptr;
117     }
118     if (rhsConst.getValue() <= 0) {
119       emitError(loc, "division by non-positive value is not supported");
120       return nullptr;
121     }
122 
123     auto lhs = visit(expr.getLHS());
124     auto rhs = visit(expr.getRHS());
125     assert(lhs && rhs && "unexpected affine expr lowering failure");
126 
127     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
128     Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
129     Value negative = builder.create<arith::CmpIOp>(
130         loc, arith::CmpIPredicate::slt, lhs, zeroCst);
131     Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
132     Value dividend =
133         builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
134     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
135     Value correctedQuotient =
136         builder.create<arith::SubIOp>(loc, noneCst, quotient);
137     Value result = builder.create<arith::SelectOp>(loc, negative,
138                                                    correctedQuotient, quotient);
139     return result;
140   }
141 
142   /// Ceiling division operation (rounds towards positive infinity).
143   ///
144   /// For positive divisors, it can be implemented without branching and with a
145   /// single division operation as
146   ///
147   ///     a ceildiv b =
148   ///         let negative = a <= 0 in
149   ///         let absolute = negative ? -a : a - 1 in
150   ///         let quotient = absolute / b in
151   ///             negative ? -quotient : quotient + 1
152   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
153     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
154     if (!rhsConst) {
155       emitError(loc) << "semi-affine expressions (division by non-const) are "
156                         "not supported";
157       return nullptr;
158     }
159     if (rhsConst.getValue() <= 0) {
160       emitError(loc, "division by non-positive value is not supported");
161       return nullptr;
162     }
163     auto lhs = visit(expr.getLHS());
164     auto rhs = visit(expr.getRHS());
165     assert(lhs && rhs && "unexpected affine expr lowering failure");
166 
167     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
168     Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
169     Value nonPositive = builder.create<arith::CmpIOp>(
170         loc, arith::CmpIPredicate::sle, lhs, zeroCst);
171     Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
172     Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
173     Value dividend =
174         builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
175     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
176     Value negatedQuotient =
177         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
178     Value incrementedQuotient =
179         builder.create<arith::AddIOp>(loc, quotient, oneCst);
180     Value result = builder.create<arith::SelectOp>(
181         loc, nonPositive, negatedQuotient, incrementedQuotient);
182     return result;
183   }
184 
185   Value visitConstantExpr(AffineConstantExpr expr) {
186     auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
187     return op.getResult();
188   }
189 
190   Value visitDimExpr(AffineDimExpr expr) {
191     assert(expr.getPosition() < dimValues.size() &&
192            "affine dim position out of range");
193     return dimValues[expr.getPosition()];
194   }
195 
196   Value visitSymbolExpr(AffineSymbolExpr expr) {
197     assert(expr.getPosition() < symbolValues.size() &&
198            "symbol dim position out of range");
199     return symbolValues[expr.getPosition()];
200   }
201 
202 private:
203   OpBuilder &builder;
204   ValueRange dimValues;
205   ValueRange symbolValues;
206 
207   Location loc;
208 };
209 } // namespace
210 
211 /// Create a sequence of operations that implement the `expr` applied to the
212 /// given dimension and symbol values.
213 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
214                                    AffineExpr expr, ValueRange dimValues,
215                                    ValueRange symbolValues) {
216   return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
217 }
218 
219 /// Create a sequence of operations that implement the `affineMap` applied to
220 /// the given `operands` (as it it were an AffineApplyOp).
221 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
222                                                       Location loc,
223                                                       AffineMap affineMap,
224                                                       ValueRange operands) {
225   auto numDims = affineMap.getNumDims();
226   auto expanded = llvm::to_vector<8>(
227       llvm::map_range(affineMap.getResults(),
228                       [numDims, &builder, loc, operands](AffineExpr expr) {
229                         return expandAffineExpr(builder, loc, expr,
230                                                 operands.take_front(numDims),
231                                                 operands.drop_front(numDims));
232                       }));
233   if (llvm::all_of(expanded, [](Value v) { return v; }))
234     return expanded;
235   return None;
236 }
237 
238 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether
239 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
240 /// the rest of the op.
241 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
242   if (elseBlock)
243     assert(ifOp.hasElse() && "else block expected");
244 
245   Block *destBlock = ifOp->getBlock();
246   Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
247   destBlock->getOperations().splice(
248       Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
249       std::prev(srcBlock->end()));
250   ifOp.erase();
251 }
252 
253 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
254 /// on. The `ifOp` could be hoisted and placed right before such an operation.
255 /// This method assumes that the ifOp has been canonicalized (to be correct and
256 /// effective).
257 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
258   // Walk up the parents past all for op that this conditional is invariant on.
259   auto ifOperands = ifOp.getOperands();
260   auto *res = ifOp.getOperation();
261   while (!isa<FuncOp>(res->getParentOp())) {
262     auto *parentOp = res->getParentOp();
263     if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
264       if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
265         break;
266     } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
267       for (auto iv : parallelOp.getIVs())
268         if (llvm::is_contained(ifOperands, iv))
269           break;
270     } else if (!isa<AffineIfOp>(parentOp)) {
271       // Won't walk up past anything other than affine.for/if ops.
272       break;
273     }
274     // You can always hoist up past any affine.if ops.
275     res = parentOp;
276   }
277   return res;
278 }
279 
/// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
/// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
/// otherwise the same `ifOp`.
///
/// `hoistOverOp` is expected to be `ifOp` itself or one of its ancestors (the
/// caller passes the result of getOutermostInvariantForOp). When hoisting
/// happens, the resulting structure is
///   hoisted affine.if (same condition and operands as `ifOp`)
///     then: `hoistOverOp` with `ifOp` replaced by its then-block contents
///     else: a clone of `hoistOverOp` with `ifOp` replaced by its else-block
///           contents (or simply removed if `ifOp` had no else block)
/// and the original `ifOp` is erased.
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
  // No hoisting to do.
  if (hoistOverOp == ifOp)
    return ifOp;

  // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
  // the else block. Then drop the else block of the original 'if' in the 'then'
  // branch while promoting its then block, and analogously drop the 'then'
  // block of the original 'if' from the 'else' branch while promoting its else
  // block.
  BlockAndValueMapping operandMap;
  OpBuilder b(hoistOverOp);
  auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
                                          ifOp.getOperands(),
                                          /*elseBlock=*/true);

  // Create a clone of hoistOverOp to use for the else branch of the hoisted
  // conditional. The else block may get optimized away if empty.
  Operation *hoistOverOpClone = nullptr;
  // We use this unique name to identify/find `ifOp`'s clone in the else
  // version.
  StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
  operandMap.clear();
  b.setInsertionPointAfter(hoistOverOp);
  // We'll set an attribute to identify this op in a clone of this sub-tree.
  // The marker must be set before cloning so the clone inherits it; the
  // original `ifOp` carrying it is erased by the promotion just below.
  ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);

  // Promote the 'then' block of the original affine.if in the then version.
  promoteIfBlock(ifOp, /*elseBlock=*/false);

  // Move the then version to the hoisted if op's 'then' block.
  // (Splices the single op `hoistOverOp` to the front of the then block.)
  auto *thenBlock = hoistedIfOp.getThenBlock();
  thenBlock->getOperations().splice(thenBlock->begin(),
                                    hoistOverOp->getBlock()->getOperations(),
                                    Block::iterator(hoistOverOp));

  // Find the clone of the original affine.if op in the else version.
  AffineIfOp ifCloneInElse;
  hoistOverOpClone->walk([&](AffineIfOp ifClone) {
    if (!ifClone->getAttr(idForIfOp))
      return WalkResult::advance();
    ifCloneInElse = ifClone;
    return WalkResult::interrupt();
  });
  assert(ifCloneInElse && "if op clone should exist");
  // For the else block, promote the else block of the original 'if' if it had
  // one; otherwise, the op itself is to be erased. Either way the marked clone
  // is erased, so the marker attribute does not survive.
  if (!ifCloneInElse.hasElse())
    ifCloneInElse.erase();
  else
    promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);

  // Move the else version into the else block of the hoisted if op.
  auto *elseBlock = hoistedIfOp.getElseBlock();
  elseBlock->getOperations().splice(
      elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
      Block::iterator(hoistOverOpClone));

  return hoistedIfOp;
}
344 
/// Replaces `forOp` with an equivalent 1-D affine.parallel op that has the
/// same bounds and step, and moves `forOp`'s body into it.
///
/// Entry i of `parallelReductions` is matched with iter operand i of `forOp`.
/// The number of entries must equal the number of iter operands; otherwise
/// this returns failure before touching the IR.
///
/// For each reduction, the op defining the i-th yielded value is moved out of
/// the loop, to just after the new parallel op. Its operands are reset to
/// (init value, parallel op result i), and uses of `forOp` result i are
/// redirected to its first result. `forOp` is erased on success.
///
/// NOTE(review): assumes each reduction is a single two-operand op producing
/// the yielded value (see the in-loop comment); multi-op reductions are not
/// handled.
LogicalResult
mlir::affineParallelize(AffineForOp forOp,
                        ArrayRef<LoopReduction> parallelReductions) {
  // Fail early if there are iter arguments that are not reductions.
  unsigned numReductions = parallelReductions.size();
  if (numReductions != forOp.getNumIterOperands())
    return failure();

  Location loc = forOp.getLoc();
  OpBuilder outsideBuilder(forOp);
  AffineMap lowerBoundMap = forOp.getLowerBoundMap();
  ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
  AffineMap upperBoundMap = forOp.getUpperBoundMap();
  ValueRange upperBoundOperands = forOp.getUpperBoundOperands();

  // Creating empty 1-D affine.parallel op.
  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.value; }));
  auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.kind; }));
  AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
      loc, ValueRange(reducedValues).getTypes(), reductionKinds,
      llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
      llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
      llvm::makeArrayRef(forOp.getStep()));
  // Steal the body of the old affine for op.
  newPloop.region().takeBody(forOp.region());
  Operation *yieldOp = &newPloop.getBody()->back();

  // Handle the initial values of reductions because the parallel loop always
  // starts from the neutral value.
  SmallVector<Value> newResults;
  newResults.reserve(numReductions);
  for (unsigned i = 0; i < numReductions; ++i) {
    Value init = forOp.getIterOperands()[i];
    // This works because we are only handling single-op reductions at the
    // moment. A switch on reduction kind or a mechanism to collect operations
    // participating in the reduction will be necessary for multi-op reductions.
    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
    assert(reductionOp && "yielded value is expected to be produced by an op");
    // The builder still points at `forOp` (the parallel op was inserted before
    // it), so the reduction op lands between the parallel op and `forOp`.
    outsideBuilder.getInsertionBlock()->getOperations().splice(
        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
        reductionOp);
    reductionOp->setOperands({init, newPloop->getResult(i)});
    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
  }

  // Update the loop terminator to yield reduced values bypassing the reduction
  // operation itself (now moved outside of the loop) and erase the block
  // arguments that correspond to reductions. Note that the loop always has one
  // "main" induction variable when coming from a non-parallel for.
  unsigned numIVs = 1;
  yieldOp->setOperands(reducedValues);
  newPloop.getBody()->eraseArguments(
      llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));

  forOp.erase();
  return success();
}
404 
405 // Returns success if any hoisting happened.
406 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
407   // Bail out early if the ifOp returns a result.  TODO: Consider how to
408   // properly support this case.
409   if (ifOp.getNumResults() != 0)
410     return failure();
411 
412   // Apply canonicalization patterns and folding - this is necessary for the
413   // hoisting check to be correct (operands should be composed), and to be more
414   // effective (no unused operands). Since the pattern rewriter's folding is
415   // entangled with application of patterns, we may fold/end up erasing the op,
416   // in which case we return with `folded` being set.
417   RewritePatternSet patterns(ifOp.getContext());
418   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
419   bool erased;
420   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
421   (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
422   if (erased) {
423     if (folded)
424       *folded = true;
425     return failure();
426   }
427   if (folded)
428     *folded = false;
429 
430   // The folding above should have ensured this, but the affine.if's
431   // canonicalization is missing composition of affine.applys into it.
432   assert(llvm::all_of(ifOp.getOperands(),
433                       [](Value v) {
434                         return isTopLevelValue(v) || isForInductionVar(v);
435                       }) &&
436          "operands not composed");
437 
438   // We are going hoist as high as possible.
439   // TODO: this could be customized in the future.
440   auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
441 
442   AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
443   // Nothing to hoist over.
444   if (hoistedIfOp == ifOp)
445     return failure();
446 
447   // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
448   // a sequence of affine.fors that are all perfectly nested).
449   (void)applyPatternsAndFoldGreedily(
450       hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
451       frozenPatterns);
452 
453   return success();
454 }
455 
456 // Return the min expr after replacing the given dim.
457 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
458                               AffineExpr max, bool positivePath) {
459   if (e == dim)
460     return positivePath ? min : max;
461   if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
462     AffineExpr lhs = bin.getLHS();
463     AffineExpr rhs = bin.getRHS();
464     if (bin.getKind() == mlir::AffineExprKind::Add)
465       return substWithMin(lhs, dim, min, max, positivePath) +
466              substWithMin(rhs, dim, min, max, positivePath);
467 
468     auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
469     auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
470     if (c1 && c1.getValue() < 0)
471       return getAffineBinaryOpExpr(
472           bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
473     if (c2 && c2.getValue() < 0)
474       return getAffineBinaryOpExpr(
475           bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
476     return getAffineBinaryOpExpr(
477         bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
478         substWithMin(rhs, dim, min, max, positivePath));
479   }
480   return e;
481 }
482 
483 void mlir::normalizeAffineParallel(AffineParallelOp op) {
484   // Loops with min/max in bounds are not normalized at the moment.
485   if (op.hasMinMaxBounds())
486     return;
487 
488   AffineMap lbMap = op.lowerBoundsMap();
489   SmallVector<int64_t, 8> steps = op.getSteps();
490   // No need to do any work if the parallel op is already normalized.
491   bool isAlreadyNormalized =
492       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
493         int64_t step = std::get<0>(tuple);
494         auto lbExpr =
495             std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
496         return lbExpr && lbExpr.getValue() == 0 && step == 1;
497       });
498   if (isAlreadyNormalized)
499     return;
500 
501   AffineValueMap ranges;
502   AffineValueMap::difference(op.getUpperBoundsValueMap(),
503                              op.getLowerBoundsValueMap(), &ranges);
504   auto builder = OpBuilder::atBlockBegin(op.getBody());
505   auto zeroExpr = builder.getAffineConstantExpr(0);
506   SmallVector<AffineExpr, 8> lbExprs;
507   SmallVector<AffineExpr, 8> ubExprs;
508   for (unsigned i = 0, e = steps.size(); i < e; ++i) {
509     int64_t step = steps[i];
510 
511     // Adjust the lower bound to be 0.
512     lbExprs.push_back(zeroExpr);
513 
514     // Adjust the upper bound expression: 'range / step'.
515     AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
516     ubExprs.push_back(ubExpr);
517 
518     // Adjust the corresponding IV: 'lb + i * step'.
519     BlockArgument iv = op.getBody()->getArgument(i);
520     AffineExpr lbExpr = lbMap.getResult(i);
521     unsigned nDims = lbMap.getNumDims();
522     auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
523     auto map = AffineMap::get(/*dimCount=*/nDims + 1,
524                               /*symbolCount=*/lbMap.getNumSymbols(), expr);
525 
526     // Use an 'affine.apply' op that will be simplified later in subsequent
527     // canonicalizations.
528     OperandRange lbOperands = op.getLowerBoundsOperands();
529     OperandRange dimOperands = lbOperands.take_front(nDims);
530     OperandRange symbolOperands = lbOperands.drop_front(nDims);
531     SmallVector<Value, 8> applyOperands{dimOperands};
532     applyOperands.push_back(iv);
533     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
534     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
535     iv.replaceAllUsesExcept(apply, apply);
536   }
537 
538   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
539   op.setSteps(newSteps);
540   auto newLowerMap = AffineMap::get(
541       /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
542   op.setLowerBounds({}, newLowerMap);
543   auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
544                                     ubExprs, op.getContext());
545   op.setUpperBounds(ranges.getOperands(), newUpperMap);
546 }
547 
548 /// Normalizes affine.for ops. If the affine.for op has only a single iteration
549 /// only then it is simply promoted, else it is normalized in the traditional
550 /// way, by converting the lower bound to zero and loop step to one. The upper
551 /// bound is set to the trip count of the loop. For now, original loops must
552 /// have lower bound with a single result only. There is no such restriction on
553 /// upper bounds.
554 LogicalResult mlir::normalizeAffineFor(AffineForOp op) {
555   if (succeeded(promoteIfSingleIteration(op)))
556     return success();
557 
558   // Check if the forop is already normalized.
559   if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
560       (op.getStep() == 1))
561     return success();
562 
563   // Check if the lower bound has a single result only. Loops with a max lower
564   // bound can't be normalized without additional support like
565   // affine.execute_region's. If the lower bound does not have a single result
566   // then skip this op.
567   if (op.getLowerBoundMap().getNumResults() != 1)
568     return failure();
569 
570   Location loc = op.getLoc();
571   OpBuilder opBuilder(op);
572   int64_t origLoopStep = op.getStep();
573 
574   // Calculate upperBound for normalized loop.
575   SmallVector<Value, 4> ubOperands;
576   AffineBound lb = op.getLowerBound();
577   AffineBound ub = op.getUpperBound();
578   ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
579   AffineMap origLbMap = lb.getMap();
580   AffineMap origUbMap = ub.getMap();
581 
582   // Add dimension operands from upper/lower bound.
583   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
584     ubOperands.push_back(ub.getOperand(j));
585   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
586     ubOperands.push_back(lb.getOperand(j));
587 
588   // Add symbol operands from upper/lower bound.
589   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
590     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
591   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
592     ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
593 
594   // Add original result expressions from lower/upper bound map.
595   SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
596                                          origLbMap.getResults().end());
597   SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
598                                          origUbMap.getResults().end());
599   SmallVector<AffineExpr, 4> newUbExprs;
600 
601   // The original upperBound can have more than one result. For the new
602   // upperBound of this loop, take difference of all possible combinations of
603   // the ub results and lb result and ceildiv with the loop step. For e.g.,
604   //
605   //  affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0)
606   //  will have an upperBound map as,
607   //  affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1, (1024 - 0) ceildiv
608   //  1)>(%i0)
609   //
610   // Insert all combinations of upper/lower bound results.
611   for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) {
612     newUbExprs.push_back(
613         (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep));
614   }
615 
616   // Construct newUbMap.
617   AffineMap newUbMap =
618       AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
619                      origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
620                      newUbExprs, opBuilder.getContext());
621   canonicalizeMapAndOperands(&newUbMap, &ubOperands);
622 
623   // Normalize the loop.
624   op.setUpperBound(ubOperands, newUbMap);
625   op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
626   op.setStep(1);
627 
628   // Calculate the Value of new loopIV. Create affine.apply for the value of
629   // the loopIV in normalized loop.
630   opBuilder.setInsertionPointToStart(op.getBody());
631   SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
632                                    lb.getOperands().begin() +
633                                        lb.getMap().getNumDims());
634   // Add an extra dim operand for loopIV.
635   lbOperands.push_back(op.getInductionVar());
636   // Add symbol operands from lower bound.
637   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
638     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
639 
640   AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
641   AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
642   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
643                                    origLbMap.getNumSymbols(), newIVExpr);
644   canonicalizeMapAndOperands(&ivMap, &lbOperands);
645   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
646   op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
647   return success();
648 }
649 
650 /// Ensure that all operations that could be executed after `start`
651 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
652 /// between the operations) do not have the potential memory effect
653 /// `EffectType` on `memOp`. `memOp`  is an operation that reads or writes to
654 /// a memref. For example, if `EffectType` is MemoryEffects::Write, this method
655 /// will check if there is no write to the memory between `start` and `memOp`
656 /// that would change the read within `memOp`.
template <typename EffectType, typename T>
static bool hasNoInterveningEffect(Operation *start, T memOp) {
  Value memref = memOp.getMemRef();
  // Whether `memref` is directly the result of a memref.alloca/memref.alloc.
  // If so, effects on a *different* alloc/alloca result are treated below as
  // not aliasing `memref`.
  bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() ||
                              memref.getDefiningOp<memref::AllocOp>();

  // A boolean representing whether an intervening operation could have impacted
  // memOp. It is only ever set to true; once set, further checks bail out.
  bool hasSideEffect = false;

  // Check whether the effect on memOp can be caused by a given operation op.
  // Ops implementing MemoryEffectOpInterface are queried directly; ops with
  // recursive side effects are checked by visiting their nested ops; any other
  // op is conservatively assumed to have the effect.
  std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // If the effect has already been found, exit early.
    if (hasSideEffect)
      return;

    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);

      bool opMayHaveEffect = false;
      for (auto effect : effects) {
        // If op causes EffectType on a potentially aliasing location for
        // memOp, mark as having the effect.
        if (isa<EffectType>(effect.getEffect())) {
          // When both `memref` and the affected value come straight from
          // alloc/alloca ops, they can only alias if they are the same value;
          // skip effects on a different allocation.
          if (isOriginalAllocation && effect.getValue() &&
              (effect.getValue().getDefiningOp<memref::AllocaOp>() ||
               effect.getValue().getDefiningOp<memref::AllocOp>())) {
            if (effect.getValue() != memref)
              continue;
          }
          opMayHaveEffect = true;
          break;
        }
      }

      if (!opMayHaveEffect)
        return;

      // If the side effect comes from an affine read or write, try to
      // prove the side effecting `op` cannot reach `memOp`.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        MemRefAccess srcAccess(op);
        MemRefAccess destAccess(memOp);
        // Dependence analysis is only correct if both ops operate on the same
        // memref.
        if (srcAccess.memref == destAccess.memref) {
          FlatAffineValueConstraints dependenceConstraints;

          // Number of loops containing the start op and the ending operation.
          unsigned minSurroundingLoops =
              getNumCommonSurroundingLoops(*start, *memOp);

          // Number of loops containing the operation `op` which has the
          // potential memory side effect and can occur on a path between
          // `start` and `memOp`.
          unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp);

          // For ease, let's consider the case that `op` is a store and we're
          // looking for other potential stores (e.g `op`) that overwrite memory
          // after `start`, and before being read in `memOp`. In this case, we
          // only need to consider other potential stores with depth >
          // minSurrounding loops since `start` would overwrite any store with a
          // smaller number of surrounding loops before.
          //
          // Dependences are checked at depths nsLoops + 1 down to
          // minSurroundingLoops + 1 (inclusive); a dependence at any of these
          // depths counts as an intervening effect.
          // NOTE(review): if nsLoops + 1 <= minSurroundingLoops the loop body
          // never runs and `op` is treated as having no effect.
          unsigned d;
          for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
            DependenceResult result = checkMemrefAccessDependence(
                srcAccess, destAccess, d, &dependenceConstraints,
                /*dependenceComponents=*/nullptr);
            if (hasDependence(result)) {
              hasSideEffect = true;
              return;
            }
          }

          // No side effect was seen, simply return.
          return;
        }
      }
      // Any other op with the effect (including an affine access through a
      // different memref value) is conservatively treated as intervening.
      hasSideEffect = true;
      return;
    }

    if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
      // Recurse into the regions for this op and check whether the internal
      // operations may have the side effect `EffectType` on memOp.
      // Note: the loop variable `op` below shadows the lambda parameter.
      for (Region &region : op->getRegions())
        for (Block &block : region)
          for (Operation &op : block)
            checkOperation(&op);
      return;
    }

    // Otherwise, conservatively assume that an op which neither declares its
    // memory effects nor has recursive side effects has the effect on memOp.
    hasSideEffect = true;
  };

  // Check all paths from ancestor op `parent` to the operation `to` for the
  // effect. It is known that `to` must be contained within `parent`.
  auto until = [&](Operation *parent, Operation *to) {
    // TODO check only the paths from `parent` to `to`.
    // Currently we fallback and check the entire parent op, rather than
    // just the paths from the parent path, stopping after reaching `to`.
    // This is conservatively correct, but could be made more aggressive.
    assert(parent->isAncestor(to));
    checkOperation(parent);
  };

  // Check for all paths from operation `from` to operation `untilOp` for the
  // given memory effect.
  std::function<void(Operation *, Operation *)> recur =
      [&](Operation *from, Operation *untilOp) {
        assert(
            from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
            "Checking for side effect between two operations without a common "
            "ancestor");

        // If the operations are in different regions, recursively consider all
        // path from `from` to the parent of `to` and all paths from the parent
        // of `to` to `to`.
        if (from->getParentRegion() != untilOp->getParentRegion()) {
          recur(from, untilOp->getParentOp());
          until(untilOp->getParentOp(), untilOp);
          return;
        }

        // Now, assuming that `from` and `to` exist in the same region, perform
        // a CFG traversal to check all the relevant operations.

        // Additional blocks to consider.
        SmallVector<Block *, 2> todoBlocks;
        {
          // First consider the parent block of `from` and check all operations
          // after `from` (stopping early if `untilOp` is reached).
          for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
               iter != end && &*iter != untilOp; ++iter) {
            checkOperation(&*iter);
          }

          // If the parent of `from` doesn't contain `to`, add the successors
          // to the list of blocks to check.
          if (untilOp->getBlock() != from->getBlock())
            for (Block *succ : from->getBlock()->getSuccessors())
              todoBlocks.push_back(succ);
        }

        // Blocks already visited by the worklist traversal below.
        SmallPtrSet<Block *, 4> done;
        // Traverse the CFG until hitting `to`. A block containing `untilOp` is
        // only checked up to `untilOp` and its successors are not enqueued;
        // other blocks are checked fully and their successors are enqueued.
        while (!todoBlocks.empty()) {
          Block *blk = todoBlocks.pop_back_val();
          if (done.count(blk))
            continue;
          done.insert(blk);
          for (auto &op : *blk) {
            if (&op == untilOp)
              break;
            checkOperation(&op);
            if (&op == blk->getTerminator())
              for (Block *succ : blk->getSuccessors())
                todoBlocks.push_back(succ);
          }
        }
      };
  recur(start, memOp);
  return !hasSideEffect;
}
824 
825 /// Attempt to eliminate loadOp by replacing it with a value stored into memory
826 /// which the load is guaranteed to retrieve. This check involves three
827 /// components: 1) The store and load must be on the same location 2) The store
828 /// must dominate (and therefore must always occur prior to) the load 3) No
829 /// other operations will overwrite the memory loaded between the given load
830 /// and store.  If such a value exists, the replaced `loadOp` will be added to
831 /// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
832 static LogicalResult forwardStoreToLoad(
833     AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
834     SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
835 
836   // The store op candidate for forwarding that satisfies all conditions
837   // to replace the load, if any.
838   Operation *lastWriteStoreOp = nullptr;
839 
840   for (auto *user : loadOp.getMemRef().getUsers()) {
841     auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
842     if (!storeOp)
843       continue;
844     MemRefAccess srcAccess(storeOp);
845     MemRefAccess destAccess(loadOp);
846 
847     // 1. Check if the store and the load have mathematically equivalent
848     // affine access functions; this implies that they statically refer to the
849     // same single memref element. As an example this filters out cases like:
850     //     store %A[%i0 + 1]
851     //     load %A[%i0]
852     //     store %A[%M]
853     //     load %A[%N]
854     // Use the AffineValueMap difference based memref access equality checking.
855     if (srcAccess != destAccess)
856       continue;
857 
858     // 2. The store has to dominate the load op to be candidate.
859     if (!domInfo.dominates(storeOp, loadOp))
860       continue;
861 
862     // 3. Ensure there is no intermediate operation which could replace the
863     // value in memory.
864     if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
865       continue;
866 
867     // We now have a candidate for forwarding.
868     assert(lastWriteStoreOp == nullptr &&
869            "multiple simulataneous replacement stores");
870     lastWriteStoreOp = storeOp;
871   }
872 
873   if (!lastWriteStoreOp)
874     return failure();
875 
876   // Perform the actual store to load forwarding.
877   Value storeVal =
878       cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
879   // Check if 2 values have the same shape. This is needed for affine vector
880   // loads and stores.
881   if (storeVal.getType() != loadOp.getValue().getType())
882     return failure();
883   loadOp.getValue().replaceAllUsesWith(storeVal);
884   // Record the memref for a later sweep to optimize away.
885   memrefsToErase.insert(loadOp.getMemRef());
886   // Record this to erase later.
887   loadOpsToErase.push_back(loadOp);
888   return success();
889 }
890 
891 // This attempts to find stores which have no impact on the final result.
892 // A writing op writeA will be eliminated if there exists an op writeB if
893 // 1) writeA and writeB have mathematically equivalent affine access functions.
894 // 2) writeB postdominates writeA.
895 // 3) There is no potential read between writeA and writeB.
896 static void findUnusedStore(AffineWriteOpInterface writeA,
897                             SmallVectorImpl<Operation *> &opsToErase,
898                             PostDominanceInfo &postDominanceInfo) {
899 
900   for (Operation *user : writeA.getMemRef().getUsers()) {
901     // Only consider writing operations.
902     auto writeB = dyn_cast<AffineWriteOpInterface>(user);
903     if (!writeB)
904       continue;
905 
906     // The operations must be distinct.
907     if (writeB == writeA)
908       continue;
909 
910     // Both operations must lie in the same region.
911     if (writeB->getParentRegion() != writeA->getParentRegion())
912       continue;
913 
914     // Both operations must write to the same memory.
915     MemRefAccess srcAccess(writeB);
916     MemRefAccess destAccess(writeA);
917 
918     if (srcAccess != destAccess)
919       continue;
920 
921     // writeB must postdominate writeA.
922     if (!postDominanceInfo.postDominates(writeB, writeA))
923       continue;
924 
925     // There cannot be an operation which reads from memory between
926     // the two writes.
927     if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
928       continue;
929 
930     opsToErase.push_back(writeA);
931     break;
932   }
933 }
934 
935 // The load to load forwarding / redundant load elimination is similar to the
936 // store to load forwarding.
937 // loadA will be be replaced with loadB if:
938 // 1) loadA and loadB have mathematically equivalent affine access functions.
939 // 2) loadB dominates loadA.
940 // 3) There is no write between loadA and loadB.
941 static void loadCSE(AffineReadOpInterface loadA,
942                     SmallVectorImpl<Operation *> &loadOpsToErase,
943                     DominanceInfo &domInfo) {
944   SmallVector<AffineReadOpInterface, 4> loadCandidates;
945   for (auto *user : loadA.getMemRef().getUsers()) {
946     auto loadB = dyn_cast<AffineReadOpInterface>(user);
947     if (!loadB || loadB == loadA)
948       continue;
949 
950     MemRefAccess srcAccess(loadB);
951     MemRefAccess destAccess(loadA);
952 
953     // 1. The accesses have to be to the same location.
954     if (srcAccess != destAccess) {
955       continue;
956     }
957 
958     // 2. The store has to dominate the load op to be candidate.
959     if (!domInfo.dominates(loadB, loadA))
960       continue;
961 
962     // 3. There is no write between loadA and loadB.
963     if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
964                                                       loadA))
965       continue;
966 
967     // Check if two values have the same shape. This is needed for affine vector
968     // loads.
969     if (loadB.getValue().getType() != loadA.getValue().getType())
970       continue;
971 
972     loadCandidates.push_back(loadB);
973   }
974 
975   // Of the legal load candidates, use the one that dominates all others
976   // to minimize the subsequent need to loadCSE
977   Value loadB;
978   for (AffineReadOpInterface option : loadCandidates) {
979     if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
980           return depStore == option ||
981                  domInfo.dominates(option.getOperation(),
982                                    depStore.getOperation());
983         })) {
984       loadB = option.getValue();
985       break;
986     }
987   }
988 
989   if (loadB) {
990     loadA.getValue().replaceAllUsesWith(loadB);
991     // Record this to erase later.
992     loadOpsToErase.push_back(loadA);
993   }
994 }
995 
996 // The store to load forwarding and load CSE rely on three conditions:
997 //
998 // 1) store/load providing a replacement value and load being replaced need to
999 // have mathematically equivalent affine access functions (checked after full
1000 // composition of load/store operands); this implies that they access the same
1001 // single memref element for all iterations of the common surrounding loop,
1002 //
1003 // 2) the store/load op should dominate the load op,
1004 //
1005 // 3) no operation that may write to memory read by the load being replaced can
1006 // occur after executing the instruction (load or store) providing the
1007 // replacement value and before the load being replaced (thus potentially
1008 // allowing overwriting the memory read by the load).
1009 //
1010 // The above conditions are simple to check, sufficient, and powerful for most
1011 // cases in practice - they are sufficient, but not necessary --- since they
1012 // don't reason about loops that are guaranteed to execute at least once or
1013 // multiple sources to forward from.
1014 //
1015 // TODO: more forwarding can be done when support for
1016 // loop/conditional live-out SSA values is available.
1017 // TODO: do general dead store elimination for memref's. This pass
1018 // currently only eliminates the stores only if no other loads/uses (other
1019 // than dealloc) remain.
1020 //
1021 void mlir::affineScalarReplace(FuncOp f, DominanceInfo &domInfo,
1022                                PostDominanceInfo &postDomInfo) {
1023   // Load op's whose results were replaced by those forwarded from stores.
1024   SmallVector<Operation *, 8> opsToErase;
1025 
1026   // A list of memref's that are potentially dead / could be eliminated.
1027   SmallPtrSet<Value, 4> memrefsToErase;
1028 
1029   // Walk all load's and perform store to load forwarding.
1030   f.walk([&](AffineReadOpInterface loadOp) {
1031     if (failed(
1032             forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
1033       loadCSE(loadOp, opsToErase, domInfo);
1034     }
1035   });
1036 
1037   // Erase all load op's whose results were replaced with store fwd'ed ones.
1038   for (auto *op : opsToErase)
1039     op->erase();
1040   opsToErase.clear();
1041 
1042   // Walk all store's and perform unused store elimination
1043   f.walk([&](AffineWriteOpInterface storeOp) {
1044     findUnusedStore(storeOp, opsToErase, postDomInfo);
1045   });
1046   // Erase all store op's which don't impact the program
1047   for (auto *op : opsToErase)
1048     op->erase();
1049 
1050   // Check if the store fwd'ed memrefs are now left with only stores and can
1051   // thus be completely deleted. Note: the canonicalize pass should be able
1052   // to do this as well, but we'll do it here since we collected these anyway.
1053   for (auto memref : memrefsToErase) {
1054     // If the memref hasn't been alloc'ed in this function, skip.
1055     Operation *defOp = memref.getDefiningOp();
1056     if (!defOp || !isa<memref::AllocOp>(defOp))
1057       // TODO: if the memref was returned by a 'call' operation, we
1058       // could still erase it if the call had no side-effects.
1059       continue;
1060     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
1061           return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
1062         }))
1063       continue;
1064 
1065     // Erase all stores, the dealloc, and the alloc on the memref.
1066     for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
1067       user->erase();
1068     defOp->erase();
1069   }
1070 }
1071 
1072 // Perform the replacement in `op`.
1073 LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
1074                                              Operation *op,
1075                                              ArrayRef<Value> extraIndices,
1076                                              AffineMap indexRemap,
1077                                              ArrayRef<Value> extraOperands,
1078                                              ArrayRef<Value> symbolOperands,
1079                                              bool allowNonDereferencingOps) {
1080   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1081   (void)newMemRefRank; // unused in opt mode
1082   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1083   (void)oldMemRefRank; // unused in opt mode
1084   if (indexRemap) {
1085     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1086            "symbolic operand count mismatch");
1087     assert(indexRemap.getNumInputs() ==
1088            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1089     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1090   } else {
1091     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1092   }
1093 
1094   // Assert same elemental type.
1095   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1096          newMemRef.getType().cast<MemRefType>().getElementType());
1097 
1098   SmallVector<unsigned, 2> usePositions;
1099   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
1100     if (opEntry.value() == oldMemRef)
1101       usePositions.push_back(opEntry.index());
1102   }
1103 
1104   // If memref doesn't appear, nothing to do.
1105   if (usePositions.empty())
1106     return success();
1107 
1108   if (usePositions.size() > 1) {
1109     // TODO: extend it for this case when needed (rare).
1110     assert(false && "multiple dereferencing uses in a single op not supported");
1111     return failure();
1112   }
1113 
1114   unsigned memRefOperandPos = usePositions.front();
1115 
1116   OpBuilder builder(op);
1117   // The following checks if op is dereferencing memref and performs the access
1118   // index rewrites.
1119   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1120   if (!affMapAccInterface) {
1121     if (!allowNonDereferencingOps) {
1122       // Failure: memref used in a non-dereferencing context (potentially
1123       // escapes); no replacement in these cases unless allowNonDereferencingOps
1124       // is set.
1125       return failure();
1126     }
1127     op->setOperand(memRefOperandPos, newMemRef);
1128     return success();
1129   }
1130   // Perform index rewrites for the dereferencing op and then replace the op
1131   NamedAttribute oldMapAttrPair =
1132       affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1133   AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
1134   unsigned oldMapNumInputs = oldMap.getNumInputs();
1135   SmallVector<Value, 4> oldMapOperands(
1136       op->operand_begin() + memRefOperandPos + 1,
1137       op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1138 
1139   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1140   SmallVector<Value, 4> oldMemRefOperands;
1141   SmallVector<Value, 4> affineApplyOps;
1142   oldMemRefOperands.reserve(oldMemRefRank);
1143   if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
1144     for (auto resultExpr : oldMap.getResults()) {
1145       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
1146                                          oldMap.getNumSymbols(), resultExpr);
1147       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1148                                                 oldMapOperands);
1149       oldMemRefOperands.push_back(afOp);
1150       affineApplyOps.push_back(afOp);
1151     }
1152   } else {
1153     oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1154   }
1155 
1156   // Construct new indices as a remap of the old ones if a remapping has been
1157   // provided. The indices of a memref come right after it, i.e.,
1158   // at position memRefOperandPos + 1.
1159   SmallVector<Value, 4> remapOperands;
1160   remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1161                         symbolOperands.size());
1162   remapOperands.append(extraOperands.begin(), extraOperands.end());
1163   remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1164   remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1165 
1166   SmallVector<Value, 4> remapOutputs;
1167   remapOutputs.reserve(oldMemRefRank);
1168 
1169   if (indexRemap &&
1170       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1171     // Remapped indices.
1172     for (auto resultExpr : indexRemap.getResults()) {
1173       auto singleResMap = AffineMap::get(
1174           indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1175       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1176                                                 remapOperands);
1177       remapOutputs.push_back(afOp);
1178       affineApplyOps.push_back(afOp);
1179     }
1180   } else {
1181     // No remapping specified.
1182     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1183   }
1184 
1185   SmallVector<Value, 4> newMapOperands;
1186   newMapOperands.reserve(newMemRefRank);
1187 
1188   // Prepend 'extraIndices' in 'newMapOperands'.
1189   for (Value extraIndex : extraIndices) {
1190     assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
1191            "single result op's expected to generate these indices");
1192     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1193            "invalid memory op index");
1194     newMapOperands.push_back(extraIndex);
1195   }
1196 
1197   // Append 'remapOutputs' to 'newMapOperands'.
1198   newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1199 
1200   // Create new fully composed AffineMap for new op to be created.
1201   assert(newMapOperands.size() == newMemRefRank);
1202   auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
1203   // TODO: Avoid creating/deleting temporary AffineApplyOps here.
1204   fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
1205   newMap = simplifyAffineMap(newMap);
1206   canonicalizeMapAndOperands(&newMap, &newMapOperands);
1207   // Remove any affine.apply's that became dead as a result of composition.
1208   for (Value value : affineApplyOps)
1209     if (value.use_empty())
1210       value.getDefiningOp()->erase();
1211 
1212   OperationState state(op->getLoc(), op->getName());
1213   // Construct the new operation using this memref.
1214   state.operands.reserve(op->getNumOperands() + extraIndices.size());
1215   // Insert the non-memref operands.
1216   state.operands.append(op->operand_begin(),
1217                         op->operand_begin() + memRefOperandPos);
1218   // Insert the new memref value.
1219   state.operands.push_back(newMemRef);
1220 
1221   // Insert the new memref map operands.
1222   state.operands.append(newMapOperands.begin(), newMapOperands.end());
1223 
1224   // Insert the remaining operands unmodified.
1225   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
1226                             oldMapNumInputs,
1227                         op->operand_end());
1228 
1229   // Result types don't change. Both memref's are of the same elemental type.
1230   state.types.reserve(op->getNumResults());
1231   for (auto result : op->getResults())
1232     state.types.push_back(result.getType());
1233 
1234   // Add attribute for 'newMap', other Attributes do not change.
1235   auto newMapAttr = AffineMapAttr::get(newMap);
1236   for (auto namedAttr : op->getAttrs()) {
1237     if (namedAttr.getName() == oldMapAttrPair.getName())
1238       state.attributes.push_back({namedAttr.getName(), newMapAttr});
1239     else
1240       state.attributes.push_back(namedAttr);
1241   }
1242 
1243   // Create the new operation.
1244   auto *repOp = builder.createOperation(state);
1245   op->replaceAllUsesWith(repOp);
1246   op->erase();
1247 
1248   return success();
1249 }
1250 
1251 LogicalResult mlir::replaceAllMemRefUsesWith(
1252     Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
1253     AffineMap indexRemap, ArrayRef<Value> extraOperands,
1254     ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1255     Operation *postDomOpFilter, bool allowNonDereferencingOps,
1256     bool replaceInDeallocOp) {
1257   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1258   (void)newMemRefRank; // unused in opt mode
1259   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1260   (void)oldMemRefRank;
1261   if (indexRemap) {
1262     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1263            "symbol operand count mismatch");
1264     assert(indexRemap.getNumInputs() ==
1265            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1266     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1267   } else {
1268     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1269   }
1270 
1271   // Assert same elemental type.
1272   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1273          newMemRef.getType().cast<MemRefType>().getElementType());
1274 
1275   std::unique_ptr<DominanceInfo> domInfo;
1276   std::unique_ptr<PostDominanceInfo> postDomInfo;
1277   if (domOpFilter)
1278     domInfo =
1279         std::make_unique<DominanceInfo>(domOpFilter->getParentOfType<FuncOp>());
1280 
1281   if (postDomOpFilter)
1282     postDomInfo = std::make_unique<PostDominanceInfo>(
1283         postDomOpFilter->getParentOfType<FuncOp>());
1284 
1285   // Walk all uses of old memref; collect ops to perform replacement. We use a
1286   // DenseSet since an operation could potentially have multiple uses of a
1287   // memref (although rare), and the replacement later is going to erase ops.
1288   DenseSet<Operation *> opsToReplace;
1289   for (auto *op : oldMemRef.getUsers()) {
1290     // Skip this use if it's not dominated by domOpFilter.
1291     if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1292       continue;
1293 
1294     // Skip this use if it's not post-dominated by postDomOpFilter.
1295     if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1296       continue;
1297 
1298     // Skip dealloc's - no replacement is necessary, and a memref replacement
1299     // at other uses doesn't hurt these dealloc's.
1300     if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
1301       continue;
1302 
1303     // Check if the memref was used in a non-dereferencing context. It is fine
1304     // for the memref to be used in a non-dereferencing way outside of the
1305     // region where this replacement is happening.
1306     if (!isa<AffineMapAccessInterface>(*op)) {
1307       if (!allowNonDereferencingOps) {
1308         LLVM_DEBUG(llvm::dbgs()
1309                    << "Memref replacement failed: non-deferencing memref op: \n"
1310                    << *op << '\n');
1311         return failure();
1312       }
1313       // Non-dereferencing ops with the MemRefsNormalizable trait are
1314       // supported for replacement.
1315       if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1316         LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
1317                                    "memrefs normalizable trait: \n"
1318                                 << *op << '\n');
1319         return failure();
1320       }
1321     }
1322 
1323     // We'll first collect and then replace --- since replacement erases the op
1324     // that has the use, and that op could be postDomFilter or domFilter itself!
1325     opsToReplace.insert(op);
1326   }
1327 
1328   for (auto *op : opsToReplace) {
1329     if (failed(replaceAllMemRefUsesWith(
1330             oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1331             symbolOperands, allowNonDereferencingOps)))
1332       llvm_unreachable("memref replacement guaranteed to succeed here");
1333   }
1334 
1335   return success();
1336 }
1337 
1338 /// Given an operation, inserts one or more single result affine
1339 /// apply operations, results of which are exclusively used by this operation
1340 /// operation. The operands of these newly created affine apply ops are
1341 /// guaranteed to be loop iterators or terminal symbols of a function.
1342 ///
1343 /// Before
1344 ///
1345 /// affine.for %i = 0 to #map(%N)
1346 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1347 ///   "send"(%idx, %A, ...)
1348 ///   "compute"(%idx)
1349 ///
1350 /// After
1351 ///
1352 /// affine.for %i = 0 to #map(%N)
1353 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1354 ///   "send"(%idx, %A, ...)
1355 ///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
1356 ///   "compute"(%idx_)
1357 ///
1358 /// This allows applying different transformations on send and compute (for eg.
1359 /// different shifts/delays).
1360 ///
1361 /// Returns nullptr either if none of opInst's operands were the result of an
1362 /// affine.apply and thus there was no affine computation slice to create, or if
1363 /// all the affine.apply op's supplying operands to this opInst did not have any
1364 /// uses besides this opInst; otherwise returns the list of affine.apply
1365 /// operations created in output argument `sliceOps`.
1366 void mlir::createAffineComputationSlice(
1367     Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1368   // Collect all operands that are results of affine apply ops.
1369   SmallVector<Value, 4> subOperands;
1370   subOperands.reserve(opInst->getNumOperands());
1371   for (auto operand : opInst->getOperands())
1372     if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
1373       subOperands.push_back(operand);
1374 
1375   // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1376   SmallVector<Operation *, 4> affineApplyOps;
1377   getReachableAffineApplyOps(subOperands, affineApplyOps);
1378   // Skip transforming if there are no affine maps to compose.
1379   if (affineApplyOps.empty())
1380     return;
1381 
1382   // Check if all uses of the affine apply op's lie only in this op op, in
1383   // which case there would be nothing to do.
1384   bool localized = true;
1385   for (auto *op : affineApplyOps) {
1386     for (auto result : op->getResults()) {
1387       for (auto *user : result.getUsers()) {
1388         if (user != opInst) {
1389           localized = false;
1390           break;
1391         }
1392       }
1393     }
1394   }
1395   if (localized)
1396     return;
1397 
1398   OpBuilder builder(opInst);
1399   SmallVector<Value, 4> composedOpOperands(subOperands);
1400   auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
1401   fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
1402 
1403   // Create an affine.apply for each of the map results.
1404   sliceOps->reserve(composedMap.getNumResults());
1405   for (auto resultExpr : composedMap.getResults()) {
1406     auto singleResMap = AffineMap::get(composedMap.getNumDims(),
1407                                        composedMap.getNumSymbols(), resultExpr);
1408     sliceOps->push_back(builder.create<AffineApplyOp>(
1409         opInst->getLoc(), singleResMap, composedOpOperands));
1410   }
1411 
1412   // Construct the new operands that include the results from the composed
1413   // affine apply op above instead of existing ones (subOperands). So, they
1414   // differ from opInst's operands only for those operands in 'subOperands', for
1415   // which they will be replaced by the corresponding one from 'sliceOps'.
1416   SmallVector<Value, 4> newOperands(opInst->getOperands());
1417   for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
1418     // Replace the subOperands from among the new operands.
1419     unsigned j, f;
1420     for (j = 0, f = subOperands.size(); j < f; j++) {
1421       if (newOperands[i] == subOperands[j])
1422         break;
1423     }
1424     if (j < subOperands.size()) {
1425       newOperands[i] = (*sliceOps)[j];
1426     }
1427   }
1428   for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
1429     opInst->setOperand(idx, newOperands[idx]);
1430   }
1431 }
1432 
/// Classification of a single result expression of a tiled-layout map.
///
/// Example:
///   #tiled_2d_128x256 = affine_map<(d0, d1)
///            -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
/// Here "d0 div 128" and "d1 div 256" are TileFloorDiv, while "d0 mod 128"
/// and "d1 mod 256" are TileMod.
enum TileExprPattern {
  /// `<dim expr> floordiv <tile size>`.
  TileFloorDiv,
  /// `<dim expr> mod <tile size>`.
  TileMod,
  /// Any other expression.
  TileNone
};
1443 
/// Check if `map` is a tiled layout. In the tiled layout, specific k dimensions
/// being floordiv'ed by respective tile sizes appear in a mod with the same
/// tile sizes, and no other expression involves those k dimensions. This
/// function stores a vector of tuples (`tileSizePos`) including AffineExpr for
/// tile size, positions of corresponding `floordiv` and `mod`. If it is not a
/// tiled layout, an empty vector is returned.
///
/// Each recorded tuple is {tile size (RHS of the `mod`), result position of the
/// `floordiv`, result position of the matching `mod`}. Only `floordiv` results
/// whose RHS is an AffineConstantExpr are considered. Tuples are appended to
/// `tileSizePos`; the vector is only reset when the map is found not to be
/// tiled.
///
/// This function always returns success(); the outcome is conveyed solely
/// through the contents of `tileSizePos`.
///
/// NOTE(review): a `floordiv` whose LHS does not appear in any other result is
/// skipped without recording a tuple and without rejecting the map, so a
/// partially tiled map can produce a non-empty `tileSizePos` that does not
/// cover every `floordiv`. Confirm this is intended by the callers.
static LogicalResult getTileSizePos(
    AffineMap map,
    SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
  // Create `floordivExprs` which is a vector of tuples including LHS and RHS of
  // `floordiv` and its position in `map` output.
  // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
  //                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
  // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
  SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
  unsigned pos = 0;
  for (AffineExpr expr : map.getResults()) {
    if (expr.getKind() == AffineExprKind::FloorDiv) {
      AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
      // Only constant tile sizes are recognized.
      if (binaryExpr.getRHS().isa<AffineConstantExpr>())
        floordivExprs.emplace_back(
            std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
    }
    pos++;
  }
  // Not tiled layout if `floordivExprs` is empty.
  if (floordivExprs.empty()) {
    tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
    return success();
  }

  // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is
  // not tiled layout.
  for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
    AffineExpr floordivExprLHS = std::get<0>(fexpr);
    AffineExpr floordivExprRHS = std::get<1>(fexpr);
    unsigned floordivPos = std::get<2>(fexpr);

    // Walk affinexpr of `map` output except `fexpr`, and check if LHS and RHS
    // of `fexpr` are used in LHS and RHS of `mod`. If LHS of `fexpr` is used
    // other expr, the map is not tiled layout. Example of non tiled layout:
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
    //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
    //   256)>
    // `found` records whether a matching `mod` was already seen for this
    // `floordiv`; a second match disqualifies the map.
    bool found = false;
    pos = 0;
    for (AffineExpr expr : map.getResults()) {
      // Set from inside the walk callback; checked once the walk is done.
      bool notTiled = false;
      if (pos != floordivPos) {
        expr.walk([&](AffineExpr e) {
          if (e == floordivExprLHS) {
            if (expr.getKind() == AffineExprKind::Mod) {
              AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
              // If LHS and RHS of `mod` are the same with those of floordiv.
              if (floordivExprLHS == binaryExpr.getLHS() &&
                  floordivExprRHS == binaryExpr.getRHS()) {
                // Save tile size (RHS of `mod`), and position of `floordiv` and
                // `mod` if same expr with `mod` is not found yet.
                if (!found) {
                  tileSizePos.emplace_back(
                      std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
                  found = true;
                } else {
                  // Non tiled layout: Have multiple `mod` with the same LHS.
                  // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
                  // mod 256, d2 mod 256)>
                  notTiled = true;
                }
              } else {
                // Non tiled layout: RHS of `mod` is different from `floordiv`.
                // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
                // mod 128)>
                notTiled = true;
              }
            } else {
              // Non tiled layout: LHS is the same, but not `mod`.
              // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
              // floordiv 256)>
              notTiled = true;
            }
          }
        });
      }
      if (notTiled) {
        // Discard any tuples collected so far: the whole map is rejected.
        tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
        return success();
      }
      pos++;
    }
  }
  return success();
}
1537 
1538 /// Check if `dim` dimension of memrefType with `layoutMap` becomes dynamic
1539 /// after normalization. Dimensions that include dynamic dimensions in the map
1540 /// output will become dynamic dimensions. Return true if `dim` is dynamic
1541 /// dimension.
1542 ///
1543 /// Example:
1544 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1545 ///
1546 /// If d1 is dynamic dimension, 2nd and 3rd dimension of map output are dynamic.
1547 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
1548 static bool
1549 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1550                              SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
1551                              MLIRContext *context) {
1552   bool isDynamicDim = false;
1553   AffineExpr expr = layoutMap.getResults()[dim];
1554   // Check if affine expr of the dimension includes dynamic dimension of input
1555   // memrefType.
1556   expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
1557     if (e.isa<AffineDimExpr>()) {
1558       for (unsigned dm : inMemrefTypeDynDims) {
1559         if (e == getAffineDimExpr(dm, context)) {
1560           isDynamicDim = true;
1561         }
1562       }
1563     }
1564   });
1565   return isDynamicDim;
1566 }
1567 
1568 /// Create affine expr to calculate dimension size for a tiled-layout map.
1569 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
1570                                                   TileExprPattern pat) {
1571   // Create map output for the patterns.
1572   // "floordiv <tile size>" ==> "ceildiv <tile size>"
1573   // "mod <tile size>" ==> "<tile size>"
1574   AffineExpr newMapOutput;
1575   AffineBinaryOpExpr binaryExpr = nullptr;
1576   switch (pat) {
1577   case TileExprPattern::TileMod:
1578     binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1579     newMapOutput = binaryExpr.getRHS();
1580     break;
1581   case TileExprPattern::TileFloorDiv:
1582     binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1583     newMapOutput = getAffineBinaryOpExpr(
1584         AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
1585     break;
1586   default:
1587     newMapOutput = oldMapOutput;
1588   }
1589   return newMapOutput;
1590 }
1591 
1592 /// Create new maps to calculate each dimension size of `newMemRefType`, and
1593 /// create `newDynamicSizes` from them by using AffineApplyOp.
1594 ///
1595 /// Steps for normalizing dynamic memrefs for a tiled layout map
1596 /// Example:
1597 ///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1598 ///    %0 = dim %arg0, %c1 :memref<4x?xf32>
1599 ///    %1 = alloc(%0) : memref<4x?xf32, #map0>
1600 ///
1601 /// (Before this function)
1602 /// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only
1603 /// single layout map is supported.
1604 ///
1605 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It
1606 /// is memref<4x?x?xf32> in the above example.
1607 ///
1608 /// (In this function)
1609 /// 3. Create new maps to calculate each dimension of the normalized memrefType
1610 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
1611 /// dimension size can be calculated by replacing "floordiv <tile size>" with
1612 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
1613 /// - New map in the above example
1614 ///   #map0 = affine_map<(d0, d1) -> (d0)>
1615 ///   #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
1616 ///   #map2 = affine_map<(d0, d1) -> (32)>
1617 ///
1618 /// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
1619 /// is used in dynamicSizes of new AllocOp.
1620 ///   %0 = dim %arg0, %c1 : memref<4x?xf32>
1621 ///   %c4 = arith.constant 4 : index
1622 ///   %1 = affine.apply #map1(%c4, %0)
1623 ///   %2 = affine.apply #map2(%c4, %0)
1624 static void createNewDynamicSizes(MemRefType oldMemRefType,
1625                                   MemRefType newMemRefType, AffineMap map,
1626                                   memref::AllocOp *allocOp, OpBuilder b,
1627                                   SmallVectorImpl<Value> &newDynamicSizes) {
1628   // Create new input for AffineApplyOp.
1629   SmallVector<Value, 4> inAffineApply;
1630   ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
1631   unsigned dynIdx = 0;
1632   for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
1633     if (oldMemRefShape[d] < 0) {
1634       // Use dynamicSizes of allocOp for dynamic dimension.
1635       inAffineApply.emplace_back(allocOp->dynamicSizes()[dynIdx]);
1636       dynIdx++;
1637     } else {
1638       // Create ConstantOp for static dimension.
1639       Attribute constantAttr =
1640           b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
1641       inAffineApply.emplace_back(
1642           b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
1643     }
1644   }
1645 
1646   // Create new map to calculate each dimension size of new memref for each
1647   // original map output. Only for dynamic dimesion of `newMemRefType`.
1648   unsigned newDimIdx = 0;
1649   ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
1650   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1651   (void)getTileSizePos(map, tileSizePos);
1652   for (AffineExpr expr : map.getResults()) {
1653     if (newMemRefShape[newDimIdx] < 0) {
1654       // Create new maps to calculate each dimension size of new memref.
1655       enum TileExprPattern pat = TileExprPattern::TileNone;
1656       for (auto pos : tileSizePos) {
1657         if (newDimIdx == std::get<1>(pos))
1658           pat = TileExprPattern::TileFloorDiv;
1659         else if (newDimIdx == std::get<2>(pos))
1660           pat = TileExprPattern::TileMod;
1661       }
1662       AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
1663       AffineMap newMap =
1664           AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
1665       Value affineApp =
1666           b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
1667       newDynamicSizes.emplace_back(affineApp);
1668     }
1669     newDimIdx++;
1670   }
1671 }
1672 
/// Replace `*allocOp` by an allocation of the normalized memref type (identity
/// layout) computed by `normalizeMemRefType`, and rewrite all uses of the old
/// memref through the old layout map via `replaceAllMemRefUsesWith`.
///
/// Static memrefs with a single layout map are supported. Dynamic memrefs are
/// supported only when the layout map is a tiled layout (see
/// `getTileSizePos`); their new dynamic sizes are computed with affine.apply
/// ops created through a builder positioned at `*allocOp`.
///
/// Returns failure() without erasing `*allocOp` if the type cannot be
/// normalized or if replacing the uses fails; on success `*allocOp` is erased.
// TODO: Normalize other kinds of dynamic memrefs (non-tiled layout maps).
LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  // Fetch a new memref type after normalizing the old memref to have an
  // identity map layout.
  MemRefType newMemRefType =
      normalizeMemRefType(memrefType, b, allocOp->symbolOperands().size());
  if (newMemRefType == memrefType)
    // Either memrefType already had an identity map or the map couldn't be
    // transformed to an identity map.
    return failure();

  Value oldMemRef = allocOp->getResult();

  SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands());
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  memref::AllocOp newAlloc;
  // Check if `layoutMap` is a tiled layout. Only single layout map is
  // supported for normalizing dynamic memrefs.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
    MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    // Add the new dynamic sizes in new AllocOp.
    newAlloc =
        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                  newDynamicSizes, allocOp->alignmentAttr());
  } else {
    // Static result type: no dynamic size operands are needed.
    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                         allocOp->alignmentAttr());
  }
  // Replace all uses of the old memref.
  if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domOpFilter=*/nullptr,
                                      /*postDomOpFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // If it failed (due to escapes for example), bail out.
    // NOTE(review): only `newAlloc` is erased here; any arith.constant /
    // affine.apply ops created by createNewDynamicSizes above are left behind
    // as dead IR.
    newAlloc.erase();
    return failure();
  }
  // Replace any uses of the original alloc op and erase it. All remaining uses
  // have to be dealloc's; RAMUW above would've failed otherwise.
  assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) {
    return isa<memref::DeallocOp>(op);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}
1731 
1732 MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
1733                                      unsigned numSymbolicOperands) {
1734   unsigned rank = memrefType.getRank();
1735   if (rank == 0)
1736     return memrefType;
1737 
1738   if (memrefType.getLayout().isIdentity()) {
1739     // Either no maps is associated with this memref or this memref has
1740     // a trivial (identity) map.
1741     return memrefType;
1742   }
1743   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1744 
1745   // We don't do any checks for one-to-one'ness; we assume that it is
1746   // one-to-one.
1747 
1748   // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
1749   // for now.
1750   // TODO: Normalize the other types of dynamic memrefs.
1751   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1752   (void)getTileSizePos(layoutMap, tileSizePos);
1753   if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1754     return memrefType;
1755 
1756   // We have a single map that is not an identity map. Create a new memref
1757   // with the right shape and an identity layout map.
1758   ArrayRef<int64_t> shape = memrefType.getShape();
1759   // FlatAffineConstraint may later on use symbolicOperands.
1760   FlatAffineConstraints fac(rank, numSymbolicOperands);
1761   SmallVector<unsigned, 4> memrefTypeDynDims;
1762   for (unsigned d = 0; d < rank; ++d) {
1763     // Use constraint system only in static dimensions.
1764     if (shape[d] > 0) {
1765       fac.addBound(FlatAffineConstraints::LB, d, 0);
1766       fac.addBound(FlatAffineConstraints::UB, d, shape[d] - 1);
1767     } else {
1768       memrefTypeDynDims.emplace_back(d);
1769     }
1770   }
1771   // We compose this map with the original index (logical) space to derive
1772   // the upper bounds for the new index space.
1773   unsigned newRank = layoutMap.getNumResults();
1774   if (failed(fac.composeMatchingMap(layoutMap)))
1775     return memrefType;
1776   // TODO: Handle semi-affine maps.
1777   // Project out the old data dimensions.
1778   fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds());
1779   SmallVector<int64_t, 4> newShape(newRank);
1780   for (unsigned d = 0; d < newRank; ++d) {
1781     // Check if each dimension of normalized memrefType is dynamic.
1782     bool isDynDim = isNormalizedMemRefDynamicDim(
1783         d, layoutMap, memrefTypeDynDims, b.getContext());
1784     if (isDynDim) {
1785       newShape[d] = -1;
1786     } else {
1787       // The lower bound for the shape is always zero.
1788       auto ubConst = fac.getConstantBound(FlatAffineConstraints::UB, d);
1789       // For a static memref and an affine map with no symbols, this is
1790       // always bounded.
1791       assert(ubConst.hasValue() && "should always have an upper bound");
1792       if (ubConst.getValue() < 0)
1793         // This is due to an invalid map that maps to a negative space.
1794         return memrefType;
1795       // If dimension of new memrefType is dynamic, the value is -1.
1796       newShape[d] = ubConst.getValue() + 1;
1797     }
1798   }
1799 
1800   // Create the new memref type after trivializing the old layout map.
1801   MemRefType newMemRefType =
1802       MemRefType::Builder(memrefType)
1803           .setShape(newShape)
1804           .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank)));
1805 
1806   return newMemRefType;
1807 }
1808