1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous transformation utilities for the Affine
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/Utils.h"
15 
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/AffineExprVisitor.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/IntegerSet.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 #define DEBUG_TYPE "affine-utils"
29 
30 using namespace mlir;
31 using namespace presburger;
32 
33 namespace {
34 /// Visit affine expressions recursively and build the sequence of operations
35 /// that corresponds to it. Visitation functions return a Value of the
36 /// expression subtree they visited or `nullptr` on error.
37 class AffineApplyExpander
38     : public AffineExprVisitor<AffineApplyExpander, Value> {
39 public:
40   /// This internal class expects arguments to be non-null; checks must be
41   /// performed at the call site.
42   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
43                       ValueRange symbolValues, Location loc)
44       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
45         loc(loc) {}
46 
47   template <typename OpTy>
48   Value buildBinaryExpr(AffineBinaryOpExpr expr) {
49     auto lhs = visit(expr.getLHS());
50     auto rhs = visit(expr.getRHS());
51     if (!lhs || !rhs)
52       return nullptr;
53     auto op = builder.create<OpTy>(loc, lhs, rhs);
54     return op.getResult();
55   }
56 
57   Value visitAddExpr(AffineBinaryOpExpr expr) {
58     return buildBinaryExpr<arith::AddIOp>(expr);
59   }
60 
61   Value visitMulExpr(AffineBinaryOpExpr expr) {
62     return buildBinaryExpr<arith::MulIOp>(expr);
63   }
64 
65   /// Euclidean modulo operation: negative RHS is not allowed.
66   /// Remainder of the euclidean integer division is always non-negative.
67   ///
68   /// Implemented as
69   ///
70   ///     a mod b =
71   ///         let remainder = srem a, b;
72   ///             negative = a < 0 in
73   ///         select negative, remainder + b, remainder.
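  /// For example (illustrative values, not from the source): with a = -7 and
  /// b = 4, `srem` gives -3, which is negative, so the lowering selects
  /// -3 + 4 = 1; with a = 7, `srem` gives 3, which is already non-negative and
  /// is returned unchanged.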
74   Value visitModExpr(AffineBinaryOpExpr expr) {
75     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
76     if (!rhsConst) {
77       emitError(
78           loc,
79           "semi-affine expressions (modulo by non-const) are not supported");
80       return nullptr;
81     }
82     if (rhsConst.getValue() <= 0) {
83       emitError(loc, "modulo by non-positive value is not supported");
84       return nullptr;
85     }
86 
87     auto lhs = visit(expr.getLHS());
88     auto rhs = visit(expr.getRHS());
89     assert(lhs && rhs && "unexpected affine expr lowering failure");
90 
91     Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
92     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
93     Value isRemainderNegative = builder.create<arith::CmpIOp>(
94         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
95     Value correctedRemainder =
96         builder.create<arith::AddIOp>(loc, remainder, rhs);
97     Value result = builder.create<arith::SelectOp>(
98         loc, isRemainderNegative, correctedRemainder, remainder);
99     return result;
100   }
101 
102   /// Floor division operation (rounds towards negative infinity).
103   ///
104   /// For positive divisors, it can be implemented without branching and with a
105   /// single division operation as
106   ///
107   ///        a floordiv b =
108   ///            let negative = a < 0 in
109   ///            let absolute = negative ? -a - 1 : a in
110   ///            let quotient = absolute / b in
111   ///                negative ? -quotient - 1 : quotient
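  /// A worked example (illustrative values): for a = -7, b = 4, `negative` is
  /// true, so the dividend is -(-7) - 1 = 6, the quotient is 6 / 4 = 1, and the
  /// result is -1 - 1 = -2, matching floor(-7 / 4). For a = 7, the result is
  /// simply 7 / 4 = 1.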
112   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
113     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
114     if (!rhsConst) {
115       emitError(
116           loc,
117           "semi-affine expressions (division by non-const) are not supported");
118       return nullptr;
119     }
120     if (rhsConst.getValue() <= 0) {
121       emitError(loc, "division by non-positive value is not supported");
122       return nullptr;
123     }
124 
125     auto lhs = visit(expr.getLHS());
126     auto rhs = visit(expr.getRHS());
127     assert(lhs && rhs && "unexpected affine expr lowering failure");
128 
129     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
130     Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
131     Value negative = builder.create<arith::CmpIOp>(
132         loc, arith::CmpIPredicate::slt, lhs, zeroCst);
133     Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
134     Value dividend =
135         builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
136     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
137     Value correctedQuotient =
138         builder.create<arith::SubIOp>(loc, noneCst, quotient);
139     Value result = builder.create<arith::SelectOp>(loc, negative,
140                                                    correctedQuotient, quotient);
141     return result;
142   }
143 
144   /// Ceiling division operation (rounds towards positive infinity).
145   ///
146   /// For positive divisors, it can be implemented without branching and with a
147   /// single division operation as
148   ///
149   ///     a ceildiv b =
150   ///         let negative = a <= 0 in
151   ///         let absolute = negative ? -a : a - 1 in
152   ///         let quotient = absolute / b in
153   ///             negative ? -quotient : quotient + 1
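  /// A worked example (illustrative values): for a = -7, b = 4, `negative` is
  /// true, so the dividend is 7, the quotient is 7 / 4 = 1, and the result is
  /// -1, matching ceil(-7 / 4). For a = 7, the dividend is 6, the quotient is
  /// 1, and the result is 1 + 1 = 2.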
154   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
155     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
156     if (!rhsConst) {
157       emitError(loc) << "semi-affine expressions (division by non-const) are "
158                         "not supported";
159       return nullptr;
160     }
161     if (rhsConst.getValue() <= 0) {
162       emitError(loc, "division by non-positive value is not supported");
163       return nullptr;
164     }
165     auto lhs = visit(expr.getLHS());
166     auto rhs = visit(expr.getRHS());
167     assert(lhs && rhs && "unexpected affine expr lowering failure");
168 
169     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
170     Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
171     Value nonPositive = builder.create<arith::CmpIOp>(
172         loc, arith::CmpIPredicate::sle, lhs, zeroCst);
173     Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
174     Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
175     Value dividend =
176         builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
177     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
178     Value negatedQuotient =
179         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
180     Value incrementedQuotient =
181         builder.create<arith::AddIOp>(loc, quotient, oneCst);
182     Value result = builder.create<arith::SelectOp>(
183         loc, nonPositive, negatedQuotient, incrementedQuotient);
184     return result;
185   }
186 
187   Value visitConstantExpr(AffineConstantExpr expr) {
188     auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
189     return op.getResult();
190   }
191 
192   Value visitDimExpr(AffineDimExpr expr) {
193     assert(expr.getPosition() < dimValues.size() &&
194            "affine dim position out of range");
195     return dimValues[expr.getPosition()];
196   }
197 
198   Value visitSymbolExpr(AffineSymbolExpr expr) {
199     assert(expr.getPosition() < symbolValues.size() &&
200            "symbol dim position out of range");
201     return symbolValues[expr.getPosition()];
202   }
203 
204 private:
205   OpBuilder &builder;
206   ValueRange dimValues;
207   ValueRange symbolValues;
208 
209   Location loc;
210 };
211 } // namespace
212 
213 /// Create a sequence of operations that implement the `expr` applied to the
214 /// given dimension and symbol values.
215 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
216                                    AffineExpr expr, ValueRange dimValues,
217                                    ValueRange symbolValues) {
218   return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
219 }
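// A minimal usage sketch (not from this file; `b`, `loc`, `dim0`, and `sym0`
// are hypothetical values in the caller's scope):
//
//   AffineExpr e = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) * 2;
//   Value v = expandAffineExpr(b, loc, e, /*dimValues=*/{dim0},
//                              /*symbolValues=*/{sym0});
//
// This emits the arith.muli/arith.addi ops that compute `d0 + s0 * 2` at `loc`.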
220 
221 /// Create a sequence of operations that implement the `affineMap` applied to
222 /// the given `operands` (as if it were an AffineApplyOp).
223 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
224                                                       Location loc,
225                                                       AffineMap affineMap,
226                                                       ValueRange operands) {
227   auto numDims = affineMap.getNumDims();
228   auto expanded = llvm::to_vector<8>(
229       llvm::map_range(affineMap.getResults(),
230                       [numDims, &builder, loc, operands](AffineExpr expr) {
231                         return expandAffineExpr(builder, loc, expr,
232                                                 operands.take_front(numDims),
233                                                 operands.drop_front(numDims));
234                       }));
235   if (llvm::all_of(expanded, [](Value v) { return v; }))
236     return expanded;
237   return None;
238 }
239 
240 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether
241 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
242 /// the rest of the op.
243 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
244   if (elseBlock)
245     assert(ifOp.hasElse() && "else block expected");
246 
247   Block *destBlock = ifOp->getBlock();
248   Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
249   destBlock->getOperations().splice(
250       Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
251       std::prev(srcBlock->end()));
252   ifOp.erase();
253 }
254 
255 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
256 /// on. The `ifOp` could be hoisted and placed right before such an operation.
257 /// This method assumes that the ifOp has been canonicalized (to be correct and
258 /// effective).
259 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
260   // Walk up the parents past all for ops that this conditional is invariant on.
261   auto ifOperands = ifOp.getOperands();
262   auto *res = ifOp.getOperation();
263   while (!isa<func::FuncOp>(res->getParentOp())) {
264     auto *parentOp = res->getParentOp();
265     if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
266       if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
267         break;
268     } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
269       if (llvm::any_of(parallelOp.getIVs(),
270             [&](Value iv) { return llvm::is_contained(ifOperands, iv); }))
271         break;
272     } else if (!isa<AffineIfOp>(parentOp)) {
273       // Won't walk up past anything other than affine.for/if ops.
274       break;
275     }
276     // You can always hoist up past any affine.if ops.
277     res = parentOp;
278   }
279   return res;
280 }
281 
282 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
283 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
284 /// otherwise the same `ifOp`.
285 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
286   // No hoisting to do.
287   if (hoistOverOp == ifOp)
288     return ifOp;
289 
290   // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
291   // the else block. Then drop the else block of the original 'if' in the 'then'
292   // branch while promoting its then block, and analogously drop the 'then'
293   // block of the original 'if' from the 'else' branch while promoting its else
294   // block.
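  // Roughly (an illustrative sketch), hoisting over an enclosing loop turns
  //
  //   affine.for %i = 0 to 64 {
  //     affine.if #set(%n) { ... }
  //   }
  //
  // into
  //
  //   affine.if #set(%n) {
  //     affine.for %i = 0 to 64 { ... }
  //   } else {
  //     affine.for %i = 0 to 64 { }  // Later removed by canonicalization.
  //   }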
295   BlockAndValueMapping operandMap;
296   OpBuilder b(hoistOverOp);
297   auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
298                                           ifOp.getOperands(),
299                                           /*elseBlock=*/true);
300 
301   // Create a clone of hoistOverOp to use for the else branch of the hoisted
302   // conditional. The else block may get optimized away if empty.
303   Operation *hoistOverOpClone = nullptr;
304   // We use this unique name to identify/find  `ifOp`'s clone in the else
305   // version.
306   StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
307   operandMap.clear();
308   b.setInsertionPointAfter(hoistOverOp);
309   // We'll set an attribute to identify this op in a clone of this sub-tree.
310   ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
311   hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
312 
313   // Promote the 'then' block of the original affine.if in the then version.
314   promoteIfBlock(ifOp, /*elseBlock=*/false);
315 
316   // Move the then version to the hoisted if op's 'then' block.
317   auto *thenBlock = hoistedIfOp.getThenBlock();
318   thenBlock->getOperations().splice(thenBlock->begin(),
319                                     hoistOverOp->getBlock()->getOperations(),
320                                     Block::iterator(hoistOverOp));
321 
322   // Find the clone of the original affine.if op in the else version.
323   AffineIfOp ifCloneInElse;
324   hoistOverOpClone->walk([&](AffineIfOp ifClone) {
325     if (!ifClone->getAttr(idForIfOp))
326       return WalkResult::advance();
327     ifCloneInElse = ifClone;
328     return WalkResult::interrupt();
329   });
330   assert(ifCloneInElse && "if op clone should exist");
331   // For the else block, promote the else block of the original 'if' if it had
332   // one; otherwise, the op itself is to be erased.
333   if (!ifCloneInElse.hasElse())
334     ifCloneInElse.erase();
335   else
336     promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
337 
338   // Move the else version into the else block of the hoisted if op.
339   auto *elseBlock = hoistedIfOp.getElseBlock();
340   elseBlock->getOperations().splice(
341       elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
342       Block::iterator(hoistOverOpClone));
343 
344   return hoistedIfOp;
345 }
346 
347 LogicalResult
348 mlir::affineParallelize(AffineForOp forOp,
349                         ArrayRef<LoopReduction> parallelReductions) {
350   // Fail early if there are iter arguments that are not reductions.
351   unsigned numReductions = parallelReductions.size();
352   if (numReductions != forOp.getNumIterOperands())
353     return failure();
354 
355   Location loc = forOp.getLoc();
356   OpBuilder outsideBuilder(forOp);
357   AffineMap lowerBoundMap = forOp.getLowerBoundMap();
358   ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
359   AffineMap upperBoundMap = forOp.getUpperBoundMap();
360   ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
361 
362   // Creating empty 1-D affine.parallel op.
363   auto reducedValues = llvm::to_vector<4>(llvm::map_range(
364       parallelReductions, [](const LoopReduction &red) { return red.value; }));
365   auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
366       parallelReductions, [](const LoopReduction &red) { return red.kind; }));
367   AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
368       loc, ValueRange(reducedValues).getTypes(), reductionKinds,
369       llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
370       llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
371       llvm::makeArrayRef(forOp.getStep()));
372   // Steal the body of the old affine for op.
373   newPloop.getRegion().takeBody(forOp.getRegion());
374   Operation *yieldOp = &newPloop.getBody()->back();
375 
376   // Handle the initial values of reductions because the parallel loop always
377   // starts from the neutral value.
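  // For instance (an illustrative sketch), for a sum reduction
  //
  //   %r = affine.for %i = 0 to 8 iter_args(%acc = %init) -> (f32) { ... }
  //
  // the rewritten affine.parallel reduces from the neutral value, and an
  // `arith.addf %init, <parallel result>` is materialized next to the new loop
  // so that uses of %r still see the initial value folded in.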
378   SmallVector<Value> newResults;
379   newResults.reserve(numReductions);
380   for (unsigned i = 0; i < numReductions; ++i) {
381     Value init = forOp.getIterOperands()[i];
382     // This works because we are only handling single-op reductions at the
383     // moment. A switch on reduction kind or a mechanism to collect operations
384     // participating in the reduction will be necessary for multi-op reductions.
385     Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
386     assert(reductionOp && "yielded value is expected to be produced by an op");
387     outsideBuilder.getInsertionBlock()->getOperations().splice(
388         outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
389         reductionOp);
390     reductionOp->setOperands({init, newPloop->getResult(i)});
391     forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
392   }
393 
394   // Update the loop terminator to yield reduced values bypassing the reduction
395   // operation itself (now moved outside of the loop) and erase the block
396   // arguments that correspond to reductions. Note that the loop always has one
397   // "main" induction variable when coming from a non-parallel for.
398   unsigned numIVs = 1;
399   yieldOp->setOperands(reducedValues);
400   newPloop.getBody()->eraseArguments(
401       llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));
402 
403   forOp.erase();
404   return success();
405 }
406 
407 // Returns success if any hoisting happened.
408 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
409   // Bail out early if the ifOp returns a result.  TODO: Consider how to
410   // properly support this case.
411   if (ifOp.getNumResults() != 0)
412     return failure();
413 
414   // Apply canonicalization patterns and folding - this is necessary for the
415   // hoisting check to be correct (operands should be composed), and to be more
416   // effective (no unused operands). Since the pattern rewriter's folding is
417   // entangled with application of patterns, we may fold/end up erasing the op,
418   // in which case we return with `folded` being set.
419   RewritePatternSet patterns(ifOp.getContext());
420   AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
421   bool erased;
422   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
423   (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
424   if (erased) {
425     if (folded)
426       *folded = true;
427     return failure();
428   }
429   if (folded)
430     *folded = false;
431 
432   // The folding above should have ensured this, but the affine.if's
433   // canonicalization is missing composition of affine.applys into it.
434   assert(llvm::all_of(ifOp.getOperands(),
435                       [](Value v) {
436                         return isTopLevelValue(v) || isForInductionVar(v);
437                       }) &&
438          "operands not composed");
439 
440   // We are going to hoist as high as possible.
441   // TODO: this could be customized in the future.
442   auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
443 
444   AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
445   // Nothing to hoist over.
446   if (hoistedIfOp == ifOp)
447     return failure();
448 
449   // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
450   // a sequence of affine.fors that are all perfectly nested).
451   (void)applyPatternsAndFoldGreedily(
452       hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
453       frozenPatterns);
454 
455   return success();
456 }
457 
458 // Return the min expr after replacing `dim` with `min` or `max` depending on the sign path.
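// For example (illustrative): with positivePath = true, substituting in
// `4 - d0`, i.e. `d0 * -1 + 4`, flips the path at the negative coefficient, so
// d0 is replaced by `max`, giving `4 - max`, the smallest value the original
// expression attains over [min, max].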
459 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
460                               AffineExpr max, bool positivePath) {
461   if (e == dim)
462     return positivePath ? min : max;
463   if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
464     AffineExpr lhs = bin.getLHS();
465     AffineExpr rhs = bin.getRHS();
466     if (bin.getKind() == mlir::AffineExprKind::Add)
467       return substWithMin(lhs, dim, min, max, positivePath) +
468              substWithMin(rhs, dim, min, max, positivePath);
469 
470     auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
471     auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
472     if (c1 && c1.getValue() < 0)
473       return getAffineBinaryOpExpr(
474           bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
475     if (c2 && c2.getValue() < 0)
476       return getAffineBinaryOpExpr(
477           bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
478     return getAffineBinaryOpExpr(
479         bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
480         substWithMin(rhs, dim, min, max, positivePath));
481   }
482   return e;
483 }
484 
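/// For example (an illustrative sketch), normalization rewrites
///
///   affine.parallel (%i) = (2) to (18) step (4) { ... }
///
/// into
///
///   affine.parallel (%i) = (0) to (4) { ... }
///
/// with an affine.apply computing `2 + %i * 4` inserted at the top of the body
/// to replace uses of the original induction variable.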
485 void mlir::normalizeAffineParallel(AffineParallelOp op) {
486   // Loops with min/max in bounds are not normalized at the moment.
487   if (op.hasMinMaxBounds())
488     return;
489 
490   AffineMap lbMap = op.getLowerBoundsMap();
491   SmallVector<int64_t, 8> steps = op.getSteps();
492   // No need to do any work if the parallel op is already normalized.
493   bool isAlreadyNormalized =
494       llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
495         int64_t step = std::get<0>(tuple);
496         auto lbExpr =
497             std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
498         return lbExpr && lbExpr.getValue() == 0 && step == 1;
499       });
500   if (isAlreadyNormalized)
501     return;
502 
503   AffineValueMap ranges;
504   AffineValueMap::difference(op.getUpperBoundsValueMap(),
505                              op.getLowerBoundsValueMap(), &ranges);
506   auto builder = OpBuilder::atBlockBegin(op.getBody());
507   auto zeroExpr = builder.getAffineConstantExpr(0);
508   SmallVector<AffineExpr, 8> lbExprs;
509   SmallVector<AffineExpr, 8> ubExprs;
510   for (unsigned i = 0, e = steps.size(); i < e; ++i) {
511     int64_t step = steps[i];
512 
513     // Adjust the lower bound to be 0.
514     lbExprs.push_back(zeroExpr);
515 
516     // Adjust the upper bound expression: 'range ceildiv step'.
517     AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
518     ubExprs.push_back(ubExpr);
519 
520     // Adjust the corresponding IV: 'lb + i * step'.
521     BlockArgument iv = op.getBody()->getArgument(i);
522     AffineExpr lbExpr = lbMap.getResult(i);
523     unsigned nDims = lbMap.getNumDims();
524     auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
525     auto map = AffineMap::get(/*dimCount=*/nDims + 1,
526                               /*symbolCount=*/lbMap.getNumSymbols(), expr);
527 
528     // Use an 'affine.apply' op that will be simplified later in subsequent
529     // canonicalizations.
530     OperandRange lbOperands = op.getLowerBoundsOperands();
531     OperandRange dimOperands = lbOperands.take_front(nDims);
532     OperandRange symbolOperands = lbOperands.drop_front(nDims);
533     SmallVector<Value, 8> applyOperands{dimOperands};
534     applyOperands.push_back(iv);
535     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
536     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
537     iv.replaceAllUsesExcept(apply, apply);
538   }
539 
540   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
541   op.setSteps(newSteps);
542   auto newLowerMap = AffineMap::get(
543       /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
544   op.setLowerBounds({}, newLowerMap);
545   auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
546                                     ubExprs, op.getContext());
547   op.setUpperBounds(ranges.getOperands(), newUpperMap);
548 }
549 
550 /// Normalizes affine.for ops. If the affine.for op has only a single
551 /// iteration, it is simply promoted; otherwise it is normalized in the
552 /// traditional way, by converting the lower bound to zero and the loop step
553 /// to one. The upper bound is set to the trip count of the loop. For now,
554 /// original loops must have a lower bound with a single result only. There is
555 /// no such restriction on upper bounds.
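/// For example (an illustrative sketch),
///
///   affine.for %i = 5 to 29 step 3 { ... }
///
/// is rewritten to
///
///   affine.for %i = 0 to 8 { ... }
///
/// with an affine.apply computing `%i * 3 + 5` inserted at the top of the body
/// to replace uses of the original induction variable.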
556 LogicalResult mlir::normalizeAffineFor(AffineForOp op) {
557   if (succeeded(promoteIfSingleIteration(op)))
558     return success();
559 
560   // Check if the for op is already normalized.
561   if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
562       (op.getStep() == 1))
563     return success();
564 
565   // Check if the lower bound has a single result only. Loops with a max lower
566   // bound can't be normalized without additional support like
567   // affine.execute_region's. If the lower bound does not have a single result
568   // then skip this op.
569   if (op.getLowerBoundMap().getNumResults() != 1)
570     return failure();
571 
572   Location loc = op.getLoc();
573   OpBuilder opBuilder(op);
574   int64_t origLoopStep = op.getStep();
575 
576   // Calculate upperBound for normalized loop.
577   SmallVector<Value, 4> ubOperands;
578   AffineBound lb = op.getLowerBound();
579   AffineBound ub = op.getUpperBound();
580   ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
581   AffineMap origLbMap = lb.getMap();
582   AffineMap origUbMap = ub.getMap();
583 
584   // Add dimension operands from upper/lower bound.
585   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
586     ubOperands.push_back(ub.getOperand(j));
587   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
588     ubOperands.push_back(lb.getOperand(j));
589 
590   // Add symbol operands from upper/lower bound.
591   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
592     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
593   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
594     ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
595 
596   // Add original result expressions from lower/upper bound map.
597   SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
598                                          origLbMap.getResults().end());
599   SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
600                                          origUbMap.getResults().end());
601   SmallVector<AffineExpr, 4> newUbExprs;
602 
603   // The original upperBound can have more than one result. For the new
604   // upperBound of this loop, take the difference of each ub result and the
605   // lb result, and ceildiv it by the loop step. For example,
606   //
607   //  affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0)
608   //  will have an upperBound map as,
609   //  affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1, (1024 - 0) ceildiv
610   //  1)>(%i0)
611   //
612   // Insert all combinations of upper/lower bound results.
613   for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) {
614     newUbExprs.push_back(
615         (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep));
616   }
617 
618   // Construct newUbMap.
619   AffineMap newUbMap =
620       AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
621                      origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
622                      newUbExprs, opBuilder.getContext());
623   canonicalizeMapAndOperands(&newUbMap, &ubOperands);
624 
625   SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
626                                    lb.getOperands().begin() +
627                                        lb.getMap().getNumDims());
628 
629   // Normalize the loop.
630   op.setUpperBound(ubOperands, newUbMap);
631   op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
632   op.setStep(1);
633 
634   // Calculate the Value of new loopIV. Create affine.apply for the value of
635   // the loopIV in normalized loop.
636   opBuilder.setInsertionPointToStart(op.getBody());
637   // Add an extra dim operand for loopIV.
638   lbOperands.push_back(op.getInductionVar());
639   // Add symbol operands from lower bound.
640   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
641     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
642 
643   AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
644   AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
645   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
646                                    origLbMap.getNumSymbols(), newIVExpr);
647   canonicalizeMapAndOperands(&ivMap, &lbOperands);
648   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
649   op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
650   return success();
651 }
652 
653 /// Ensure that all operations that could be executed after `start`
654 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
655 /// between the operations) do not have the potential memory effect
656 /// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to
657 /// a memref. For example, if `EffectType` is MemoryEffects::Write, this method
658 /// will check if there is no write to the memory between `start` and `memOp`
659 /// that would change the read within `memOp`.
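/// For example (an illustrative sketch), with `EffectType` being
/// MemoryEffects::Write:
///
///   affine.store %v0, %A[%i]   // `start`
///   affine.store %v1, %A[%i]   // Intervening write to the same element.
///   %x = affine.load %A[%i]    // `memOp`
///
/// this returns false, since the intervening store may change the value that
/// `memOp` reads.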
660 template <typename EffectType, typename T>
661 static bool hasNoInterveningEffect(Operation *start, T memOp) {
662   Value memref = memOp.getMemRef();
663   bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() ||
664                               memref.getDefiningOp<memref::AllocOp>();
665 
666   // A boolean representing whether an intervening operation could have impacted
667   // memOp.
668   bool hasSideEffect = false;
669 
670   // Check whether the effect on memOp can be caused by a given operation op.
671   std::function<void(Operation *)> checkOperation = [&](Operation *op) {
672     // If the effect has already been found, exit early.
673     if (hasSideEffect)
674       return;
675 
676     if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
677       SmallVector<MemoryEffects::EffectInstance, 1> effects;
678       memEffect.getEffects(effects);
679 
680       bool opMayHaveEffect = false;
681       for (auto effect : effects) {
682         // If op causes EffectType on a potentially aliasing location for
683         // memOp, mark as having the effect.
684         if (isa<EffectType>(effect.getEffect())) {
685           if (isOriginalAllocation && effect.getValue() &&
686               (effect.getValue().getDefiningOp<memref::AllocaOp>() ||
687                effect.getValue().getDefiningOp<memref::AllocOp>())) {
688             if (effect.getValue() != memref)
689               continue;
690           }
691           opMayHaveEffect = true;
692           break;
693         }
694       }
695 
696       if (!opMayHaveEffect)
697         return;
698 
699       // If the side effect comes from an affine read or write, try to
700       // prove the side effecting `op` cannot reach `memOp`.
701       if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
702         MemRefAccess srcAccess(op);
703         MemRefAccess destAccess(memOp);
704         // Dependence analysis is only correct if both ops operate on the same
705         // memref.
706         if (srcAccess.memref == destAccess.memref) {
707           FlatAffineValueConstraints dependenceConstraints;
708 
709           // Number of loops containing the start op and the ending operation.
710           unsigned minSurroundingLoops =
711               getNumCommonSurroundingLoops(*start, *memOp);
712 
713           // Number of loops containing the operation `op` which has the
714           // potential memory side effect and can occur on a path between
715           // `start` and `memOp`.
716           unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp);
717 
718           // For ease, let's consider the case that `op` is a store and we're
719           // looking for other potential stores (e.g `op`) that overwrite memory
720           // after `start`, and before being read in `memOp`. In this case, we
721           // only need to consider other potential stores with depth >
722           // minSurroundingLoops, since `start` would overwrite any store
723           // with a smaller number of surrounding loops beforehand.
724           unsigned d;
725           for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
726             DependenceResult result = checkMemrefAccessDependence(
727                 srcAccess, destAccess, d, &dependenceConstraints,
728                 /*dependenceComponents=*/nullptr);
729             if (hasDependence(result)) {
730               hasSideEffect = true;
731               return;
732             }
733           }
734 
735           // No side effect was seen, simply return.
736           return;
737         }
738       }
739       hasSideEffect = true;
740       return;
741     }
742 
743     if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
744       // Recurse into the regions for this op and check whether the internal
745       // operations may have the side effect `EffectType` on memOp.
746       for (Region &region : op->getRegions())
747         for (Block &block : region)
748           for (Operation &op : block)
749             checkOperation(&op);
750       return;
751     }
752 
753     // Otherwise, conservatively assume generic operations have the effect
754     // on `memOp`.
755     hasSideEffect = true;
756   };
757 
758   // Check all paths from ancestor op `parent` to the operation `to` for the
759   // effect. It is known that `to` must be contained within `parent`.
760   auto until = [&](Operation *parent, Operation *to) {
761     // TODO check only the paths from `parent` to `to`.
762     // Currently we fall back and check the entire parent op, rather than
763     // just the paths from `parent` to `to`, stopping after reaching `to`.
764     // This is conservatively correct, but could be made more aggressive.
765     assert(parent->isAncestor(to));
766     checkOperation(parent);
767   };
768 
769   // Check for all paths from operation `from` to operation `untilOp` for the
770   // given memory effect.
771   std::function<void(Operation *, Operation *)> recur =
772       [&](Operation *from, Operation *untilOp) {
773         assert(
774             from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
775             "Checking for side effect between two operations without a common "
776             "ancestor");
777 
778         // If the operations are in different regions, recursively consider
779         // all paths from `from` to the parent of `to` and all paths from the
780         // parent of `to` to `to`.
781         if (from->getParentRegion() != untilOp->getParentRegion()) {
782           recur(from, untilOp->getParentOp());
783           until(untilOp->getParentOp(), untilOp);
784           return;
785         }
786 
787         // Now, assuming that `from` and `to` exist in the same region, perform
788         // a CFG traversal to check all the relevant operations.
789 
790         // Additional blocks to consider.
791         SmallVector<Block *, 2> todoBlocks;
792         {
793           // First consider the parent block of `from` and check all operations
794           // after `from`.
795           for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
796                iter != end && &*iter != untilOp; ++iter) {
797             checkOperation(&*iter);
798           }
799 
800           // If the parent of `from` doesn't contain `to`, add the successors
801           // to the list of blocks to check.
802           if (untilOp->getBlock() != from->getBlock())
803             for (Block *succ : from->getBlock()->getSuccessors())
804               todoBlocks.push_back(succ);
805         }
806 
807         SmallPtrSet<Block *, 4> done;
808         // Traverse the CFG until hitting `to`.
809         while (!todoBlocks.empty()) {
810           Block *blk = todoBlocks.pop_back_val();
811           if (done.count(blk))
812             continue;
813           done.insert(blk);
814           for (auto &op : *blk) {
815             if (&op == untilOp)
816               break;
817             checkOperation(&op);
818             if (&op == blk->getTerminator())
819               for (Block *succ : blk->getSuccessors())
820                 todoBlocks.push_back(succ);
821           }
822         }
823       };
824   recur(start, memOp);
825   return !hasSideEffect;
826 }
827 
828 /// Attempt to eliminate loadOp by replacing it with a value stored into memory
829 /// which the load is guaranteed to retrieve. This check involves three
830 /// components: 1) the store and load must be to the same location; 2) the
831 /// store must dominate (and therefore must always occur prior to) the load;
832 /// 3) no other operation may overwrite the memory loaded from between the given
833 /// load and store. If such a value exists, the replaced `loadOp` will be added to
834 /// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
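/// For example (an illustrative sketch), in
///
///   affine.store %v, %A[%i]
///   %x = affine.load %A[%i]
///
/// all uses of %x are replaced by %v, the load is queued in `loadOpsToErase`,
/// and %A is recorded in `memrefsToErase` for the later clean-up sweep.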
835 static LogicalResult forwardStoreToLoad(
836     AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
837     SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
838 
839   // The store op candidate for forwarding that satisfies all conditions
840   // to replace the load, if any.
841   Operation *lastWriteStoreOp = nullptr;
842 
843   for (auto *user : loadOp.getMemRef().getUsers()) {
844     auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
845     if (!storeOp)
846       continue;
847     MemRefAccess srcAccess(storeOp);
848     MemRefAccess destAccess(loadOp);
849 
850     // 1. Check if the store and the load have mathematically equivalent
851     // affine access functions; this implies that they statically refer to the
852     // same single memref element. As an example this filters out cases like:
853     //     store %A[%i0 + 1]
854     //     load %A[%i0]
855     //     store %A[%M]
856     //     load %A[%N]
857     // Use the AffineValueMap difference based memref access equality checking.
858     if (srcAccess != destAccess)
859       continue;
860 
861     // 2. The store has to dominate the load op to be candidate.
862     if (!domInfo.dominates(storeOp, loadOp))
863       continue;
864 
865     // 3. Ensure there is no intermediate operation which could replace the
866     // value in memory.
867     if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
868       continue;
869 
870     // We now have a candidate for forwarding.
871     assert(lastWriteStoreOp == nullptr &&
872            "multiple simultaneous replacement stores");
873     lastWriteStoreOp = storeOp;
874   }
875 
876   if (!lastWriteStoreOp)
877     return failure();
878 
879   // Perform the actual store to load forwarding.
880   Value storeVal =
881       cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
882   // Check if 2 values have the same shape. This is needed for affine vector
883   // loads and stores.
884   if (storeVal.getType() != loadOp.getValue().getType())
885     return failure();
886   loadOp.getValue().replaceAllUsesWith(storeVal);
887   // Record the memref for a later sweep to optimize away.
888   memrefsToErase.insert(loadOp.getMemRef());
889   // Record this to erase later.
890   loadOpsToErase.push_back(loadOp);
891   return success();
892 }
893 
894 // This attempts to find stores which have no impact on the final result.
895 // A writing op writeA will be eliminated if there exists an op writeB such that
896 // 1) writeA and writeB have mathematically equivalent affine access functions.
897 // 2) writeB postdominates writeA.
898 // 3) There is no potential read between writeA and writeB.
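// For instance (an illustrative sketch), in
//
//   affine.store %v0, %A[%i]   // writeA
//   affine.store %v1, %A[%i]   // writeB, postdominates writeA
//
// writeA is queued in `opsToErase` provided nothing can read %A[%i] between
// the two stores.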
899 static void findUnusedStore(AffineWriteOpInterface writeA,
900                             SmallVectorImpl<Operation *> &opsToErase,
901                             PostDominanceInfo &postDominanceInfo) {
902 
903   for (Operation *user : writeA.getMemRef().getUsers()) {
904     // Only consider writing operations.
905     auto writeB = dyn_cast<AffineWriteOpInterface>(user);
906     if (!writeB)
907       continue;
908 
909     // The operations must be distinct.
910     if (writeB == writeA)
911       continue;
912 
913     // Both operations must lie in the same region.
914     if (writeB->getParentRegion() != writeA->getParentRegion())
915       continue;
916 
917     // Both operations must write to the same memory.
918     MemRefAccess srcAccess(writeB);
919     MemRefAccess destAccess(writeA);
920 
921     if (srcAccess != destAccess)
922       continue;
923 
924     // writeB must postdominate writeA.
925     if (!postDominanceInfo.postDominates(writeB, writeA))
926       continue;
927 
928     // There cannot be an operation which reads from memory between
929     // the two writes.
930     if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
931       continue;
932 
933     opsToErase.push_back(writeA);
934     break;
935   }
936 }
937 
938 // The load to load forwarding / redundant load elimination is similar to the
939 // store to load forwarding.
940 // loadA will be replaced with loadB if:
941 // 1) loadA and loadB have mathematically equivalent affine access functions.
942 // 2) loadB dominates loadA.
943 // 3) There is no write between loadA and loadB.
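// For example (an illustrative sketch), in
//
//   %x = affine.load %A[%i]   // loadB
//   %y = affine.load %A[%i]   // loadA
//
// uses of %y are replaced by %x provided no write to %A[%i] can occur between
// the two loads.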
944 static void loadCSE(AffineReadOpInterface loadA,
945                     SmallVectorImpl<Operation *> &loadOpsToErase,
946                     DominanceInfo &domInfo) {
947   SmallVector<AffineReadOpInterface, 4> loadCandidates;
948   for (auto *user : loadA.getMemRef().getUsers()) {
949     auto loadB = dyn_cast<AffineReadOpInterface>(user);
950     if (!loadB || loadB == loadA)
951       continue;
952 
953     MemRefAccess srcAccess(loadB);
954     MemRefAccess destAccess(loadA);
955 
956     // 1. The accesses have to be to the same location.
957     if (srcAccess != destAccess) {
958       continue;
959     }
960 
961     // 2. The store has to dominate the load op to be candidate.
962     if (!domInfo.dominates(loadB, loadA))
963       continue;
964 
965     // 3. There is no write between loadA and loadB.
966     if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
967                                                       loadA))
968       continue;
969 
970     // Check if two values have the same shape. This is needed for affine vector
971     // loads.
972     if (loadB.getValue().getType() != loadA.getValue().getType())
973       continue;
974 
975     loadCandidates.push_back(loadB);
976   }
977 
978   // Of the legal load candidates, use the one that dominates all others
979   // to minimize the subsequent need for loadCSE.
980   Value loadB;
981   for (AffineReadOpInterface option : loadCandidates) {
982     if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
983           return depStore == option ||
984                  domInfo.dominates(option.getOperation(),
985                                    depStore.getOperation());
986         })) {
987       loadB = option.getValue();
988       break;
989     }
990   }
991 
992   if (loadB) {
993     loadA.getValue().replaceAllUsesWith(loadB);
994     // Record this to erase later.
995     loadOpsToErase.push_back(loadA);
996   }
997 }
998 
999 // The store to load forwarding and load CSE rely on three conditions:
1000 //
1001 // 1) store/load providing a replacement value and load being replaced need to
1002 // have mathematically equivalent affine access functions (checked after full
1003 // composition of load/store operands); this implies that they access the same
1004 // single memref element for all iterations of the common surrounding loop,
1005 //
1006 // 2) the store/load op should dominate the load op,
1007 //
1008 // 3) no operation that may write to memory read by the load being replaced can
1009 // occur after executing the instruction (load or store) providing the
1010 // replacement value and before the load being replaced (thus potentially
1011 // allowing overwriting the memory read by the load).
1012 //
1013 // The above conditions are simple to check, sufficient, and powerful for most
1014 // cases in practice. They are sufficient but not necessary, since they don't
1015 // reason about loops that are guaranteed to execute at least once or about
1016 // multiple sources to forward from.
1017 //
1018 // TODO: more forwarding can be done when support for
1019 // loop/conditional live-out SSA values is available.
1020 // TODO: do general dead store elimination for memref's. This pass
1021 // currently eliminates the stores only if no other loads/uses (other
1022 // than dealloc) remain.
1023 //
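// A usage sketch (mirroring how a pass over a func::FuncOp might invoke this
// utility; not part of this file):
//
//   void runOnOperation() override {
//     affineScalarReplace(getOperation(), getAnalysis<DominanceInfo>(),
//                         getAnalysis<PostDominanceInfo>());
//   }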
1024 void mlir::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
1025                                PostDominanceInfo &postDomInfo) {
1026   // Load op's whose results were replaced by those forwarded from stores.
1027   SmallVector<Operation *, 8> opsToErase;
1028 
1029   // A list of memref's that are potentially dead / could be eliminated.
1030   SmallPtrSet<Value, 4> memrefsToErase;
1031 
1032   // Walk all load's and perform store to load forwarding.
1033   f.walk([&](AffineReadOpInterface loadOp) {
1034     if (failed(
1035             forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
1036       loadCSE(loadOp, opsToErase, domInfo);
1037     }
1038   });
1039 
1040   // Erase all load op's whose results were replaced with store fwd'ed ones.
1041   for (auto *op : opsToErase)
1042     op->erase();
1043   opsToErase.clear();
1044 
1045   // Walk all store's and perform unused store elimination
1046   f.walk([&](AffineWriteOpInterface storeOp) {
1047     findUnusedStore(storeOp, opsToErase, postDomInfo);
1048   });
1049   // Erase all store op's which don't impact the program
1050   for (auto *op : opsToErase)
1051     op->erase();
1052 
1053   // Check if the store fwd'ed memrefs are now left with only stores and can
1054   // thus be completely deleted. Note: the canonicalize pass should be able
1055   // to do this as well, but we'll do it here since we collected these anyway.
1056   for (auto memref : memrefsToErase) {
1057     // If the memref hasn't been alloc'ed in this function, skip.
1058     Operation *defOp = memref.getDefiningOp();
1059     if (!defOp || !isa<memref::AllocOp>(defOp))
1060       // TODO: if the memref was returned by a 'call' operation, we
1061       // could still erase it if the call had no side-effects.
1062       continue;
1063     if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
1064           return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
1065         }))
1066       continue;
1067 
1068     // Erase all stores, the dealloc, and the alloc on the memref.
1069     for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
1070       user->erase();
1071     defOp->erase();
1072   }
1073 }
1074 
1075 // Perform the replacement in `op`.
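// For example (an illustrative sketch), replacing a memref<64xf32> `%old` with
// a memref<4x16xf32> `%new` under
//
//   indexRemap = affine_map<(d0) -> (d0 floordiv 16, d0 mod 16)>
//
// turns `affine.load %old[%i]` into `affine.load %new[%i floordiv 16,
// %i mod 16]` once the remapped indices are composed and canonicalized.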
1076 LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
1077                                              Operation *op,
1078                                              ArrayRef<Value> extraIndices,
1079                                              AffineMap indexRemap,
1080                                              ArrayRef<Value> extraOperands,
1081                                              ArrayRef<Value> symbolOperands,
1082                                              bool allowNonDereferencingOps) {
1083   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1084   (void)newMemRefRank; // unused in opt mode
1085   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1086   (void)oldMemRefRank; // unused in opt mode
1087   if (indexRemap) {
1088     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1089            "symbolic operand count mismatch");
1090     assert(indexRemap.getNumInputs() ==
1091            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1092     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1093   } else {
1094     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1095   }
1096 
1097   // Assert same elemental type.
1098   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1099          newMemRef.getType().cast<MemRefType>().getElementType());
1100 
1101   SmallVector<unsigned, 2> usePositions;
1102   for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
1103     if (opEntry.value() == oldMemRef)
1104       usePositions.push_back(opEntry.index());
1105   }
1106 
1107   // If memref doesn't appear, nothing to do.
1108   if (usePositions.empty())
1109     return success();
1110 
1111   if (usePositions.size() > 1) {
1112     // TODO: extend it for this case when needed (rare).
1113     assert(false && "multiple dereferencing uses in a single op not supported");
1114     return failure();
1115   }
1116 
1117   unsigned memRefOperandPos = usePositions.front();
1118 
1119   OpBuilder builder(op);
1120   // The following checks if op is dereferencing memref and performs the access
1121   // index rewrites.
1122   auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1123   if (!affMapAccInterface) {
1124     if (!allowNonDereferencingOps) {
1125       // Failure: memref used in a non-dereferencing context (potentially
1126       // escapes); no replacement in these cases unless allowNonDereferencingOps
1127       // is set.
1128       return failure();
1129     }
1130     op->setOperand(memRefOperandPos, newMemRef);
1131     return success();
1132   }
1133   // Perform index rewrites for the dereferencing op and then replace the op
1134   NamedAttribute oldMapAttrPair =
1135       affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1136   AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
1137   unsigned oldMapNumInputs = oldMap.getNumInputs();
1138   SmallVector<Value, 4> oldMapOperands(
1139       op->operand_begin() + memRefOperandPos + 1,
1140       op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1141 
1142   // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1143   SmallVector<Value, 4> oldMemRefOperands;
1144   SmallVector<Value, 4> affineApplyOps;
1145   oldMemRefOperands.reserve(oldMemRefRank);
1146   if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
1147     for (auto resultExpr : oldMap.getResults()) {
1148       auto singleResMap = AffineMap::get(oldMap.getNumDims(),
1149                                          oldMap.getNumSymbols(), resultExpr);
1150       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1151                                                 oldMapOperands);
1152       oldMemRefOperands.push_back(afOp);
1153       affineApplyOps.push_back(afOp);
1154     }
1155   } else {
1156     oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1157   }
1158 
1159   // Construct new indices as a remap of the old ones if a remapping has been
1160   // provided. The indices of a memref come right after it, i.e.,
1161   // at position memRefOperandPos + 1.
1162   SmallVector<Value, 4> remapOperands;
1163   remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1164                         symbolOperands.size());
1165   remapOperands.append(extraOperands.begin(), extraOperands.end());
1166   remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1167   remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1168 
1169   SmallVector<Value, 4> remapOutputs;
1170   remapOutputs.reserve(oldMemRefRank);
1171 
1172   if (indexRemap &&
1173       indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1174     // Remapped indices.
1175     for (auto resultExpr : indexRemap.getResults()) {
1176       auto singleResMap = AffineMap::get(
1177           indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1178       auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1179                                                 remapOperands);
1180       remapOutputs.push_back(afOp);
1181       affineApplyOps.push_back(afOp);
1182     }
1183   } else {
1184     // No remapping specified.
1185     remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1186   }
1187 
1188   SmallVector<Value, 4> newMapOperands;
1189   newMapOperands.reserve(newMemRefRank);
1190 
1191   // Prepend 'extraIndices' in 'newMapOperands'.
1192   for (Value extraIndex : extraIndices) {
1193     assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
1194            "single result op's expected to generate these indices");
1195     assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1196            "invalid memory op index");
1197     newMapOperands.push_back(extraIndex);
1198   }
1199 
1200   // Append 'remapOutputs' to 'newMapOperands'.
1201   newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1202 
1203   // Create new fully composed AffineMap for new op to be created.
1204   assert(newMapOperands.size() == newMemRefRank);
1205   auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
1206   // TODO: Avoid creating/deleting temporary AffineApplyOps here.
1207   fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
1208   newMap = simplifyAffineMap(newMap);
1209   canonicalizeMapAndOperands(&newMap, &newMapOperands);
1210   // Remove any affine.apply's that became dead as a result of composition.
1211   for (Value value : affineApplyOps)
1212     if (value.use_empty())
1213       value.getDefiningOp()->erase();
1214 
1215   OperationState state(op->getLoc(), op->getName());
1216   // Construct the new operation using this memref.
1217   state.operands.reserve(op->getNumOperands() + extraIndices.size());
1218   // Insert the non-memref operands.
1219   state.operands.append(op->operand_begin(),
1220                         op->operand_begin() + memRefOperandPos);
1221   // Insert the new memref value.
1222   state.operands.push_back(newMemRef);
1223 
1224   // Insert the new memref map operands.
1225   state.operands.append(newMapOperands.begin(), newMapOperands.end());
1226 
1227   // Insert the remaining operands unmodified.
1228   state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
1229                             oldMapNumInputs,
1230                         op->operand_end());
1231 
1232   // Result types don't change. Both memref's are of the same elemental type.
1233   state.types.reserve(op->getNumResults());
1234   for (auto result : op->getResults())
1235     state.types.push_back(result.getType());
1236 
1237   // Add the attribute for 'newMap'; other attributes do not change.
1238   auto newMapAttr = AffineMapAttr::get(newMap);
1239   for (auto namedAttr : op->getAttrs()) {
1240     if (namedAttr.getName() == oldMapAttrPair.getName())
1241       state.attributes.push_back({namedAttr.getName(), newMapAttr});
1242     else
1243       state.attributes.push_back(namedAttr);
1244   }
1245 
1246   // Create the new operation.
1247   auto *repOp = builder.create(state);
1248   op->replaceAllUsesWith(repOp);
1249   op->erase();
1250 
1251   return success();
1252 }
1253 
1254 LogicalResult mlir::replaceAllMemRefUsesWith(
1255     Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
1256     AffineMap indexRemap, ArrayRef<Value> extraOperands,
1257     ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1258     Operation *postDomOpFilter, bool allowNonDereferencingOps,
1259     bool replaceInDeallocOp) {
1260   unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1261   (void)newMemRefRank; // unused in opt mode
1262   unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1263   (void)oldMemRefRank;
1264   if (indexRemap) {
1265     assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1266            "symbol operand count mismatch");
1267     assert(indexRemap.getNumInputs() ==
1268            extraOperands.size() + oldMemRefRank + symbolOperands.size());
1269     assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1270   } else {
1271     assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1272   }
1273 
1274   // Assert same elemental type.
1275   assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1276          newMemRef.getType().cast<MemRefType>().getElementType());
1277 
1278   std::unique_ptr<DominanceInfo> domInfo;
1279   std::unique_ptr<PostDominanceInfo> postDomInfo;
1280   if (domOpFilter)
1281     domInfo = std::make_unique<DominanceInfo>(
1282         domOpFilter->getParentOfType<func::FuncOp>());
1283 
1284   if (postDomOpFilter)
1285     postDomInfo = std::make_unique<PostDominanceInfo>(
1286         postDomOpFilter->getParentOfType<func::FuncOp>());
1287 
1288   // Walk all uses of old memref; collect ops to perform replacement. We use a
1289   // DenseSet since an operation could potentially have multiple uses of a
1290   // memref (although rare), and the replacement later is going to erase ops.
1291   DenseSet<Operation *> opsToReplace;
1292   for (auto *op : oldMemRef.getUsers()) {
1293     // Skip this use if it's not dominated by domOpFilter.
1294     if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1295       continue;
1296 
1297     // Skip this use if it's not post-dominated by postDomOpFilter.
1298     if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1299       continue;
1300 
1301     // Skip dealloc's - no replacement is necessary, and a memref replacement
1302     // at other uses doesn't hurt these dealloc's.
1303     if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
1304       continue;
1305 
1306     // Check if the memref was used in a non-dereferencing context. It is fine
1307     // for the memref to be used in a non-dereferencing way outside of the
1308     // region where this replacement is happening.
1309     if (!isa<AffineMapAccessInterface>(*op)) {
1310       if (!allowNonDereferencingOps) {
1311         LLVM_DEBUG(llvm::dbgs()
1312                    << "Memref replacement failed: non-dereferencing memref op: \n"
1313                    << *op << '\n');
1314         return failure();
1315       }
1316       // Non-dereferencing ops with the MemRefsNormalizable trait are
1317       // supported for replacement.
1318       if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1319         LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
1320                                    "memrefs normalizable trait: \n"
1321                                 << *op << '\n');
1322         return failure();
1323       }
1324     }
1325 
1326   // We'll first collect and then replace, since replacement erases the op that
1327   // has the use, and that op could be the postDomOpFilter or domOpFilter itself!
1328     opsToReplace.insert(op);
1329   }
1330 
1331   for (auto *op : opsToReplace) {
1332     if (failed(replaceAllMemRefUsesWith(
1333             oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1334             symbolOperands, allowNonDereferencingOps)))
1335       llvm_unreachable("memref replacement guaranteed to succeed here");
1336   }
1337 
1338   return success();
1339 }
1340 
1341 /// Given an operation, inserts one or more single-result affine apply
1342 /// operations, results of which are exclusively used by this operation. The
1343 /// operands of these newly created affine apply ops are guaranteed to be loop
1344 /// iterators or terminal symbols of a function.
1345 ///
1346 /// Before
1347 ///
1348 /// affine.for %i = 0 to #map(%N)
1349 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1350 ///   "send"(%idx, %A, ...)
1351 ///   "compute"(%idx)
1352 ///
1353 /// After
1354 ///
1355 /// affine.for %i = 0 to #map(%N)
1356 ///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1357 ///   "send"(%idx, %A, ...)
1358 ///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
1359 ///   "compute"(%idx_)
1360 ///
1361 /// This allows applying different transformations on the send and compute ops
1362 /// (e.g., different shifts/delays).
1363 ///
1364 /// Leaves `sliceOps` empty either if none of opInst's operands were the result
1365 /// of an affine.apply (and thus there was no affine computation slice to
1366 /// create), or if all the affine.apply ops supplying operands to this opInst
1367 /// had no uses other than this opInst; otherwise, the affine.apply operations
1368 /// created are returned in the output argument `sliceOps`.
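///
/// A minimal call sketch (illustrative; `computeOp` is an assumed Operation*
/// whose operands are fed by affine.apply ops shared with other users):
///
///   SmallVector<AffineApplyOp, 4> sliceOps;
///   createAffineComputationSlice(computeOp, &sliceOps);
///   // `sliceOps` now holds the newly created affine.apply ops (if any) whose
///   // results feed only `computeOp`.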
1369 void mlir::createAffineComputationSlice(
1370     Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1371   // Collect all operands that are results of affine apply ops.
1372   SmallVector<Value, 4> subOperands;
1373   subOperands.reserve(opInst->getNumOperands());
1374   for (auto operand : opInst->getOperands())
1375     if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
1376       subOperands.push_back(operand);
1377 
1378   // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1379   SmallVector<Operation *, 4> affineApplyOps;
1380   getReachableAffineApplyOps(subOperands, affineApplyOps);
1381   // Skip transforming if there are no affine maps to compose.
1382   if (affineApplyOps.empty())
1383     return;
1384 
1385   // Check if all uses of the affine apply ops lie only in this op, in which
1386   // case there would be nothing to do.
1387   bool localized = true;
1388   for (auto *op : affineApplyOps) {
1389     for (auto result : op->getResults()) {
1390       for (auto *user : result.getUsers()) {
1391         if (user != opInst) {
1392           localized = false;
1393           break;
1394         }
1395       }
1396     }
1397   }
1398   if (localized)
1399     return;
1400 
1401   OpBuilder builder(opInst);
1402   SmallVector<Value, 4> composedOpOperands(subOperands);
1403   auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
1404   fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
1405 
1406   // Create an affine.apply for each of the map results.
1407   sliceOps->reserve(composedMap.getNumResults());
1408   for (auto resultExpr : composedMap.getResults()) {
1409     auto singleResMap = AffineMap::get(composedMap.getNumDims(),
1410                                        composedMap.getNumSymbols(), resultExpr);
1411     sliceOps->push_back(builder.create<AffineApplyOp>(
1412         opInst->getLoc(), singleResMap, composedOpOperands));
1413   }
1414 
1415   // Construct the new operands, using the results of the composed affine
1416   // apply ops above instead of the existing ones (subOperands). They differ
1417   // from opInst's operands only at positions in 'subOperands', which are
1418   // replaced by the corresponding results from 'sliceOps'.
1419   SmallVector<Value, 4> newOperands(opInst->getOperands());
1420   for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
1421     // Replace the subOperands from among the new operands.
1422     unsigned j, f;
1423     for (j = 0, f = subOperands.size(); j < f; j++) {
1424       if (newOperands[i] == subOperands[j])
1425         break;
1426     }
1427     if (j < subOperands.size()) {
1428       newOperands[i] = (*sliceOps)[j];
1429     }
1430   }
1431   for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
1432     opInst->setOperand(idx, newOperands[idx]);
1433   }
1434 }
1435 
1436 /// Enum to set patterns of affine expr in tiled-layout map.
1437 /// TileFloorDiv: <dim expr> div <tile size>
1438 /// TileMod: <dim expr> mod <tile size>
1439 /// TileNone: None of the above
1440 /// Example:
1441 /// #tiled_2d_128x256 = affine_map<(d0, d1)
1442 ///            -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
1443 /// "d0 div 128" and "d1 div 256" ==> TileFloorDiv
1444 /// "d0 mod 128" and "d1 mod 256" ==> TileMod
1445 enum TileExprPattern { TileFloorDiv, TileMod, TileNone };
1446 
1447 /// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
1448 /// that are floordiv'ed by their respective tile sizes also appear in a mod
1449 /// with the same tile sizes, and no other expression involves those k
1450 /// dimensions. This function stores, in `tileSizePos`, a vector of tuples
1451 /// holding the tile-size AffineExpr and the positions of the corresponding
1452 /// `floordiv` and `mod`. If `map` is not a tiled layout, the vector is empty.
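///
/// As an illustration (added for clarity), for the tiled map
///   #tiled_2d_128x256 = affine_map<(d0, d1)
///              -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
/// `tileSizePos` would hold the tuples (128, /*floordiv pos=*/0, /*mod pos=*/2)
/// and (256, /*floordiv pos=*/1, /*mod pos=*/3).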
1453 static LogicalResult getTileSizePos(
1454     AffineMap map,
1455     SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
1456   // Create `floordivExprs` which is a vector of tuples including LHS and RHS of
1457   // `floordiv` and its position in `map` output.
1458   // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
1459   //                -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)>
1460   // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
1461   SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
1462   unsigned pos = 0;
1463   for (AffineExpr expr : map.getResults()) {
1464     if (expr.getKind() == AffineExprKind::FloorDiv) {
1465       AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
1466       if (binaryExpr.getRHS().isa<AffineConstantExpr>())
1467         floordivExprs.emplace_back(
1468             std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
1469     }
1470     pos++;
1471   }
1472   // Not tiled layout if `floordivExprs` is empty.
1473   if (floordivExprs.empty()) {
1474     tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1475     return success();
1476   }
1477 
1478   // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is
1479   // not tiled layout.
1480   for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
1481     AffineExpr floordivExprLHS = std::get<0>(fexpr);
1482     AffineExpr floordivExprRHS = std::get<1>(fexpr);
1483     unsigned floordivPos = std::get<2>(fexpr);
1484 
1485     // Walk each affine expr of the `map` output except `fexpr`, and check that
1486     // the LHS and RHS of `fexpr` are used as the LHS and RHS of a `mod`. If the
1487     // LHS of `fexpr` appears in any other expr, the map is not tiled. Examples:
1488     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
1489     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
1490     //   affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
1491     //   256)>
1492     bool found = false;
1493     pos = 0;
1494     for (AffineExpr expr : map.getResults()) {
1495       bool notTiled = false;
1496       if (pos != floordivPos) {
1497         expr.walk([&](AffineExpr e) {
1498           if (e == floordivExprLHS) {
1499             if (expr.getKind() == AffineExprKind::Mod) {
1500               AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
1501               // If the LHS and RHS of `mod` match those of the floordiv:
1502               if (floordivExprLHS == binaryExpr.getLHS() &&
1503                   floordivExprRHS == binaryExpr.getRHS()) {
1504                 // Save tile size (RHS of `mod`), and position of `floordiv` and
1505                 // `mod` if same expr with `mod` is not found yet.
1506                 if (!found) {
1507                   tileSizePos.emplace_back(
1508                       std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
1509                   found = true;
1510                 } else {
1511                   // Non-tiled layout: multiple `mod`s with the same LHS.
1512                   // e.g. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256,
1513                   // d2 mod 256, d2 mod 256)>
1514                   notTiled = true;
1515                 }
1516               } else {
1517                 // Non-tiled layout: the RHS of `mod` differs from that of the
1518                 // `floordiv`. e.g. affine_map<(d0, d1, d2) -> (d0, d1,
1519                 // d2 floordiv 256, d2 mod 128)>
1520                 notTiled = true;
1521               }
1522             } else {
1523               // Non-tiled layout: the LHS is the same, but the expr is not a
1524               // `mod`. e.g. affine_map<(d0, d1, d2) -> (d0, d1,
1525               // d2 floordiv 256, d2 floordiv 256)>
1526               notTiled = true;
1527             }
1528           }
1529         });
1530       }
1531       if (notTiled) {
1532         tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1533         return success();
1534       }
1535       pos++;
1536     }
1537   }
1538   return success();
1539 }
1540 
1541 /// Check if the `dim` dimension of a memrefType with `layoutMap` becomes
1542 /// dynamic after normalization. Map results that involve a dynamic input
1543 /// dimension become dynamic dimensions. Returns true if `dim` is a dynamic
1544 /// dimension.
1545 ///
1546 /// Example:
1547 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1548 ///
1549 /// If d1 is a dynamic dimension, the 2nd and 3rd map results are dynamic.
1550 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
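///
/// For instance (illustrative), with #map0 above and `inMemrefTypeDynDims` =
/// {1}, map results 1 (`d1 floordiv 32`) and 2 (`d1 mod 32`) are reported
/// dynamic, while result 0 (`d0`) is not.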
1551 static bool
1552 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1553                              SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
1554                              MLIRContext *context) {
1555   bool isDynamicDim = false;
1556   AffineExpr expr = layoutMap.getResults()[dim];
1557   // Check if affine expr of the dimension includes dynamic dimension of input
1558   // memrefType.
1559   expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
1560     if (e.isa<AffineDimExpr>()) {
1561       for (unsigned dm : inMemrefTypeDynDims) {
1562         if (e == getAffineDimExpr(dm, context)) {
1563           isDynamicDim = true;
1564         }
1565       }
1566     }
1567   });
1568   return isDynamicDim;
1569 }
1570 
1571 /// Create an affine expr to calculate the dimension size for a tiled-layout map.
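/// For example (illustrative), with a tile size of 32:
///   TileFloorDiv: `d1 floordiv 32`  ==>  `d1 ceildiv 32`
///   TileMod:      `d1 mod 32`       ==>  `32`
///   TileNone:     the expression is returned unchanged.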
1572 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
1573                                                   TileExprPattern pat) {
1574   // Create map output for the patterns.
1575   // "floordiv <tile size>" ==> "ceildiv <tile size>"
1576   // "mod <tile size>" ==> "<tile size>"
1577   AffineExpr newMapOutput;
1578   AffineBinaryOpExpr binaryExpr = nullptr;
1579   switch (pat) {
1580   case TileExprPattern::TileMod:
1581     binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1582     newMapOutput = binaryExpr.getRHS();
1583     break;
1584   case TileExprPattern::TileFloorDiv:
1585     binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1586     newMapOutput = getAffineBinaryOpExpr(
1587         AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
1588     break;
1589   default:
1590     newMapOutput = oldMapOutput;
1591   }
1592   return newMapOutput;
1593 }
1594 
1595 /// Create new maps to calculate each dimension size of `newMemRefType`, and
1596 /// create `newDynamicSizes` from them by using AffineApplyOp.
1597 ///
1598 /// Steps for normalizing dynamic memrefs for a tiled layout map
1599 /// Example:
1600 ///    #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1601 ///    %0 = dim %arg0, %c1 :memref<4x?xf32>
1602 ///    %1 = alloc(%0) : memref<4x?xf32, #map0>
1603 ///
1604 /// (Before this function)
1605 /// 1. Check if `map` (#map0) is a tiled layout using `getTileSizePos()`. Only a
1606 /// single layout map is supported.
1607 ///
1608 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It
1609 /// is memref<4x?x?xf32> in the above example.
1610 ///
1611 /// (In this function)
1612 /// 3. Create new maps to calculate each dimension of the normalized memrefType
1613 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
1614 /// dimension size can be calculated by replacing "floordiv <tile size>" with
1615 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
1616 /// - New map in the above example
1617 ///   #map0 = affine_map<(d0, d1) -> (d0)>
1618 ///   #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
1619 ///   #map2 = affine_map<(d0, d1) -> (32)>
1620 ///
1621 /// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
1622 /// is used in dynamicSizes of new AllocOp.
1623 ///   %0 = dim %arg0, %c1 : memref<4x?xf32>
1624 ///   %c4 = arith.constant 4 : index
1625 ///   %1 = affine.apply #map1(%c4, %0)
1626 ///   %2 = affine.apply #map2(%c4, %0)
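///
/// (In the caller, after this function) the results feed the dynamic sizes of
/// the new AllocOp, roughly:
///   %3 = memref.alloc(%1, %2) : memref<4x?x?xf32>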
1627 static void createNewDynamicSizes(MemRefType oldMemRefType,
1628                                   MemRefType newMemRefType, AffineMap map,
1629                                   memref::AllocOp *allocOp, OpBuilder b,
1630                                   SmallVectorImpl<Value> &newDynamicSizes) {
1631   // Create new input for AffineApplyOp.
1632   SmallVector<Value, 4> inAffineApply;
1633   ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
1634   unsigned dynIdx = 0;
1635   for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
1636     if (oldMemRefShape[d] < 0) {
1637       // Use dynamicSizes of allocOp for dynamic dimension.
1638       inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
1639       dynIdx++;
1640     } else {
1641       // Create ConstantOp for static dimension.
1642       Attribute constantAttr =
1643           b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
1644       inAffineApply.emplace_back(
1645           b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
1646     }
1647   }
1648 
1649   // Create a new map to calculate each dimension size of the new memref for
1650   // each original map output, but only for dynamic dimensions of `newMemRefType`.
1651   unsigned newDimIdx = 0;
1652   ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
1653   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1654   (void)getTileSizePos(map, tileSizePos);
1655   for (AffineExpr expr : map.getResults()) {
1656     if (newMemRefShape[newDimIdx] < 0) {
1657       // Create new maps to calculate each dimension size of new memref.
1658       enum TileExprPattern pat = TileExprPattern::TileNone;
1659       for (auto pos : tileSizePos) {
1660         if (newDimIdx == std::get<1>(pos))
1661           pat = TileExprPattern::TileFloorDiv;
1662         else if (newDimIdx == std::get<2>(pos))
1663           pat = TileExprPattern::TileMod;
1664       }
1665       AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
1666       AffineMap newMap =
1667           AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
1668       Value affineApp =
1669           b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
1670       newDynamicSizes.emplace_back(affineApp);
1671     }
1672     newDimIdx++;
1673   }
1674 }
1675 
1676 // TODO: Currently works for static memrefs with a single layout map.
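// For instance (an illustrative sketch), with the tiled layout
//   #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
// an allocation
//   %0 = memref.alloc() : memref<64x128xf32, #map0>
// is rewritten into an identity-layout allocation
//   %0 = memref.alloc() : memref<64x4x32xf32>
// with all dereferencing uses remapped via `replaceAllMemRefUsesWith`.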
1677 LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
1678   MemRefType memrefType = allocOp->getType();
1679   OpBuilder b(*allocOp);
1680 
1681   // Fetch a new memref type after normalizing the old memref to have an
1682   // identity map layout.
1683   MemRefType newMemRefType =
1684       normalizeMemRefType(memrefType, b, allocOp->getSymbolOperands().size());
1685   if (newMemRefType == memrefType)
1686     // Either memrefType already had an identity map or the map couldn't be
1687     // transformed to an identity map.
1688     return failure();
1689 
1690   Value oldMemRef = allocOp->getResult();
1691 
1692   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
1693   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1694   memref::AllocOp newAlloc;
1695   // Check if `layoutMap` is a tiled layout. Only a single layout map is
1696   // supported for normalizing dynamic memrefs.
1697   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1698   (void)getTileSizePos(layoutMap, tileSizePos);
1699   if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
1700     MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
1701     SmallVector<Value, 4> newDynamicSizes;
1702     createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
1703                           newDynamicSizes);
1704     // Add the new dynamic sizes in new AllocOp.
1705     newAlloc =
1706         b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1707                                   newDynamicSizes, allocOp->getAlignmentAttr());
1708   } else {
1709     newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1710                                          allocOp->getAlignmentAttr());
1711   }
1712   // Replace all uses of the old memref.
1713   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
1714                                       /*extraIndices=*/{},
1715                                       /*indexRemap=*/layoutMap,
1716                                       /*extraOperands=*/{},
1717                                       /*symbolOperands=*/symbolOperands,
1718                                       /*domOpFilter=*/nullptr,
1719                                       /*postDomOpFilter=*/nullptr,
1720                                       /*allowNonDereferencingOps=*/true))) {
1721     // If it failed (due to escapes for example), bail out.
1722     newAlloc.erase();
1723     return failure();
1724   }
1725   // Replace any uses of the original alloc op and erase it. All remaining uses
1726   // have to be dealloc's; RAMUW above would've failed otherwise.
1727   assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) {
1728     return isa<memref::DeallocOp>(op);
1729   }));
1730   oldMemRef.replaceAllUsesWith(newAlloc);
1731   allocOp->erase();
1732   return success();
1733 }
1734 
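// For example (an illustrative sketch), a transposed layout is folded into the
// shape:
//   memref<8x16xf32, affine_map<(d0, d1) -> (d1, d0)>>  ==>  memref<16x8xf32>
// while an identity-layout or rank-0 memref is returned unchanged.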
1735 MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
1736                                      unsigned numSymbolicOperands) {
1737   unsigned rank = memrefType.getRank();
1738   if (rank == 0)
1739     return memrefType;
1740 
1741   if (memrefType.getLayout().isIdentity()) {
1742     // Either no map is associated with this memref or this memref has
1743     // a trivial (identity) map.
1744     return memrefType;
1745   }
1746   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1747 
1748   // We don't do any checks for one-to-one'ness; we assume that it is
1749   // one-to-one.
1750 
1751   // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
1752   // for now.
1753   // TODO: Normalize the other types of dynamic memrefs.
1754   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1755   (void)getTileSizePos(layoutMap, tileSizePos);
1756   if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1757     return memrefType;
1758 
1759   // We have a single map that is not an identity map. Create a new memref
1760   // with the right shape and an identity layout map.
1761   ArrayRef<int64_t> shape = memrefType.getShape();
1762   // FlatAffineValueConstraints may later on use symbolicOperands.
1763   FlatAffineValueConstraints fac(rank, numSymbolicOperands);
1764   SmallVector<unsigned, 4> memrefTypeDynDims;
1765   for (unsigned d = 0; d < rank; ++d) {
1766     // Use the constraint system only for static dimensions.
1767     if (shape[d] > 0) {
1768       fac.addBound(IntegerPolyhedron::LB, d, 0);
1769       fac.addBound(IntegerPolyhedron::UB, d, shape[d] - 1);
1770     } else {
1771       memrefTypeDynDims.emplace_back(d);
1772     }
1773   }
1774   // We compose this map with the original index (logical) space to derive
1775   // the upper bounds for the new index space.
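  // For instance (illustrative), for memref<8x16xf32> with layout
  // (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4), composing 0 <= d0 < 8 and
  // 0 <= d1 < 16 with the map gives constant upper bounds 7, 3, and 3 for the
  // three results, i.e. a normalized shape of 8x4x4.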
1776   unsigned newRank = layoutMap.getNumResults();
1777   if (failed(fac.composeMatchingMap(layoutMap)))
1778     return memrefType;
1779   // TODO: Handle semi-affine maps.
1780   // Project out the old data dimensions.
1781   fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
1782   SmallVector<int64_t, 4> newShape(newRank);
1783   for (unsigned d = 0; d < newRank; ++d) {
1784     // Check if each dimension of normalized memrefType is dynamic.
1785     bool isDynDim = isNormalizedMemRefDynamicDim(
1786         d, layoutMap, memrefTypeDynDims, b.getContext());
1787     if (isDynDim) {
1788       newShape[d] = -1;
1789     } else {
1790       // The lower bound for the shape is always zero.
1791       auto ubConst = fac.getConstantBound(IntegerPolyhedron::UB, d);
1792       // For a static memref and an affine map with no symbols, this is
1793       // always bounded.
1794       assert(ubConst && "should always have an upper bound");
1795       if (ubConst.value() < 0)
1796         // This is due to an invalid map that maps to a negative space.
1797         return memrefType;
1798       // For a static dimension, the size is its constant upper bound plus one.
1799       newShape[d] = ubConst.value() + 1;
1800     }
1801   }
1802 
1803   // Create the new memref type after trivializing the old layout map.
1804   MemRefType newMemRefType =
1805       MemRefType::Builder(memrefType)
1806           .setShape(newShape)
1807           .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank)));
1808 
1809   return newMemRefType;
1810 }
1811