1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous transformation utilities for the Affine
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Affine/Utils.h"
15
16 #include "mlir/Dialect/Affine/Analysis/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/IR/AffineExprVisitor.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/IntegerSet.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 #define DEBUG_TYPE "affine-utils"
29
30 using namespace mlir;
31 using namespace presburger;
32
33 namespace {
34 /// Visit affine expressions recursively and build the sequence of operations
35 /// that corresponds to them. Visitation functions return a Value of the
36 /// expression subtree they visited or `nullptr` on error.
37 class AffineApplyExpander
38 : public AffineExprVisitor<AffineApplyExpander, Value> {
39 public:
40   /// This internal class expects arguments to be non-null; checks must be
41 /// performed at the call site.
42   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
43 ValueRange symbolValues, Location loc)
44 : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
45 loc(loc) {}
46
47 template <typename OpTy>
48   Value buildBinaryExpr(AffineBinaryOpExpr expr) {
49 auto lhs = visit(expr.getLHS());
50 auto rhs = visit(expr.getRHS());
51 if (!lhs || !rhs)
52 return nullptr;
53 auto op = builder.create<OpTy>(loc, lhs, rhs);
54 return op.getResult();
55 }
56
57   Value visitAddExpr(AffineBinaryOpExpr expr) {
58 return buildBinaryExpr<arith::AddIOp>(expr);
59 }
60
61   Value visitMulExpr(AffineBinaryOpExpr expr) {
62 return buildBinaryExpr<arith::MulIOp>(expr);
63 }
64
65 /// Euclidean modulo operation: negative RHS is not allowed.
66 /// Remainder of the euclidean integer division is always non-negative.
67 ///
68 /// Implemented as
69 ///
70 /// a mod b =
71 /// let remainder = srem a, b;
72 /// negative = a < 0 in
73 /// select negative, remainder + b, remainder.
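  ///
  /// For instance (an illustrative check, not from the original source): with
  /// a = -7 and b = 3, `srem` yields -1; since the remainder is negative, the
  /// select picks -1 + 3 = 2, the Euclidean result of -7 mod 3.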
74   Value visitModExpr(AffineBinaryOpExpr expr) {
75 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
76 if (!rhsConst) {
77 emitError(
78 loc,
79 "semi-affine expressions (modulo by non-const) are not supported");
80 return nullptr;
81 }
82 if (rhsConst.getValue() <= 0) {
83 emitError(loc, "modulo by non-positive value is not supported");
84 return nullptr;
85 }
86
87 auto lhs = visit(expr.getLHS());
88 auto rhs = visit(expr.getRHS());
89 assert(lhs && rhs && "unexpected affine expr lowering failure");
90
91 Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
92 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
93 Value isRemainderNegative = builder.create<arith::CmpIOp>(
94 loc, arith::CmpIPredicate::slt, remainder, zeroCst);
95 Value correctedRemainder =
96 builder.create<arith::AddIOp>(loc, remainder, rhs);
97 Value result = builder.create<arith::SelectOp>(
98 loc, isRemainderNegative, correctedRemainder, remainder);
99 return result;
100 }
101
102 /// Floor division operation (rounds towards negative infinity).
103 ///
104 /// For positive divisors, it can be implemented without branching and with a
105 /// single division operation as
106 ///
107 /// a floordiv b =
108 /// let negative = a < 0 in
109 /// let absolute = negative ? -a - 1 : a in
110 /// let quotient = absolute / b in
111 /// negative ? -quotient - 1 : quotient
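  ///
  /// For instance (an illustrative check, not from the original source): with
  /// a = -7 and b = 3, negative is true, absolute = 6, quotient = 2, and the
  /// result is -2 - 1 = -3, i.e. floor(-7 / 3).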
112   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
113 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
114 if (!rhsConst) {
115 emitError(
116 loc,
117 "semi-affine expressions (division by non-const) are not supported");
118 return nullptr;
119 }
120 if (rhsConst.getValue() <= 0) {
121 emitError(loc, "division by non-positive value is not supported");
122 return nullptr;
123 }
124
125 auto lhs = visit(expr.getLHS());
126 auto rhs = visit(expr.getRHS());
127 assert(lhs && rhs && "unexpected affine expr lowering failure");
128
129 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
130 Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
131 Value negative = builder.create<arith::CmpIOp>(
132 loc, arith::CmpIPredicate::slt, lhs, zeroCst);
133 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
134 Value dividend =
135 builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
136 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
137 Value correctedQuotient =
138 builder.create<arith::SubIOp>(loc, noneCst, quotient);
139 Value result = builder.create<arith::SelectOp>(loc, negative,
140 correctedQuotient, quotient);
141 return result;
142 }
143
144 /// Ceiling division operation (rounds towards positive infinity).
145 ///
146 /// For positive divisors, it can be implemented without branching and with a
147 /// single division operation as
148 ///
149 /// a ceildiv b =
150 /// let negative = a <= 0 in
151 /// let absolute = negative ? -a : a - 1 in
152 /// let quotient = absolute / b in
153 /// negative ? -quotient : quotient + 1
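  ///
  /// For instance (an illustrative check, not from the original source): with
  /// b = 3, a = 7 gives absolute = 6, quotient = 2, result = 3; a = -7 gives
  /// absolute = 7, quotient = 2, result = -2, i.e. ceil(-7 / 3).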
154   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
155 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
156 if (!rhsConst) {
157 emitError(loc) << "semi-affine expressions (division by non-const) are "
158 "not supported";
159 return nullptr;
160 }
161 if (rhsConst.getValue() <= 0) {
162 emitError(loc, "division by non-positive value is not supported");
163 return nullptr;
164 }
165 auto lhs = visit(expr.getLHS());
166 auto rhs = visit(expr.getRHS());
167 assert(lhs && rhs && "unexpected affine expr lowering failure");
168
169 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
170 Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
171 Value nonPositive = builder.create<arith::CmpIOp>(
172 loc, arith::CmpIPredicate::sle, lhs, zeroCst);
173 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
174 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
175 Value dividend =
176 builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
177 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
178 Value negatedQuotient =
179 builder.create<arith::SubIOp>(loc, zeroCst, quotient);
180 Value incrementedQuotient =
181 builder.create<arith::AddIOp>(loc, quotient, oneCst);
182 Value result = builder.create<arith::SelectOp>(
183 loc, nonPositive, negatedQuotient, incrementedQuotient);
184 return result;
185 }
186
187   Value visitConstantExpr(AffineConstantExpr expr) {
188 auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
189 return op.getResult();
190 }
191
192   Value visitDimExpr(AffineDimExpr expr) {
193 assert(expr.getPosition() < dimValues.size() &&
194 "affine dim position out of range");
195 return dimValues[expr.getPosition()];
196 }
197
198   Value visitSymbolExpr(AffineSymbolExpr expr) {
199 assert(expr.getPosition() < symbolValues.size() &&
200 "symbol dim position out of range");
201 return symbolValues[expr.getPosition()];
202 }
203
204 private:
205 OpBuilder &builder;
206 ValueRange dimValues;
207 ValueRange symbolValues;
208
209 Location loc;
210 };
211 } // namespace
212
213 /// Create a sequence of operations that implement the `expr` applied to the
214 /// given dimension and symbol values.
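/// For instance (an illustrative sketch, not from the original source),
/// expanding `d0 + s0 * 42` materializes the constant 42 (arith.constant),
/// multiplies it with the symbol value (arith.muli), and adds the dimension
/// value (arith.addi); the addi's result is returned.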
215 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
216 AffineExpr expr, ValueRange dimValues,
217 ValueRange symbolValues) {
218 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
219 }
220
221 /// Create a sequence of operations that implement the `affineMap` applied to
222 /// the given `operands` (as if it were an AffineApplyOp).
223 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
224 Location loc,
225 AffineMap affineMap,
226 ValueRange operands) {
227 auto numDims = affineMap.getNumDims();
228 auto expanded = llvm::to_vector<8>(
229 llvm::map_range(affineMap.getResults(),
230 [numDims, &builder, loc, operands](AffineExpr expr) {
231 return expandAffineExpr(builder, loc, expr,
232 operands.take_front(numDims),
233 operands.drop_front(numDims));
234 }));
235 if (llvm::all_of(expanded, [](Value v) { return v; }))
236 return expanded;
237 return None;
238 }
239
240 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether
241 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards
242 /// the rest of the op.
243 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
244 if (elseBlock)
245 assert(ifOp.hasElse() && "else block expected");
246
247 Block *destBlock = ifOp->getBlock();
248 Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
249 destBlock->getOperations().splice(
250 Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(),
251 std::prev(srcBlock->end()));
252 ifOp.erase();
253 }
254
255 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant
256 /// on. The `ifOp` could be hoisted and placed right before such an operation.
257 /// This method assumes that the ifOp has been canonicalized (to be correct and
258 /// effective).
259 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
260   // Walk up the parents past all for ops that this conditional is invariant on.
261 auto ifOperands = ifOp.getOperands();
262 auto *res = ifOp.getOperation();
263 while (!isa<func::FuncOp>(res->getParentOp())) {
264 auto *parentOp = res->getParentOp();
265 if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
266 if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
267 break;
268 } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
269       if (llvm::any_of(parallelOp.getIVs(),
270                        [&](Value iv) { return llvm::is_contained(ifOperands, iv); }))
271         break;
272 } else if (!isa<AffineIfOp>(parentOp)) {
273 // Won't walk up past anything other than affine.for/if ops.
274 break;
275 }
276 // You can always hoist up past any affine.if ops.
277 res = parentOp;
278 }
279 return res;
280 }
281
282 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over
283 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened,
284 /// otherwise the same `ifOp`.
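/// For instance (an illustrative sketch, not from the original source),
/// hoisting an `affine.if` that is invariant on the surrounding loop:
///
///   affine.for %i = 0 to %n {
///     affine.if #set(%m) { "then"() } else { "else"() }
///   }
///
/// becomes
///
///   affine.if #set(%m) {
///     affine.for %i = 0 to %n { "then"() }
///   } else {
///     affine.for %i = 0 to %n { "else"() }
///   }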
285 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
286 // No hoisting to do.
287 if (hoistOverOp == ifOp)
288 return ifOp;
289
290 // Create the hoisted 'if' first. Then, clone the op we are hoisting over for
291 // the else block. Then drop the else block of the original 'if' in the 'then'
292 // branch while promoting its then block, and analogously drop the 'then'
293 // block of the original 'if' from the 'else' branch while promoting its else
294 // block.
295 BlockAndValueMapping operandMap;
296 OpBuilder b(hoistOverOp);
297 auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
298 ifOp.getOperands(),
299 /*elseBlock=*/true);
300
301 // Create a clone of hoistOverOp to use for the else branch of the hoisted
302 // conditional. The else block may get optimized away if empty.
303 Operation *hoistOverOpClone = nullptr;
304 // We use this unique name to identify/find `ifOp`'s clone in the else
305 // version.
306 StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
307 operandMap.clear();
308 b.setInsertionPointAfter(hoistOverOp);
309 // We'll set an attribute to identify this op in a clone of this sub-tree.
310 ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
311 hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
312
313 // Promote the 'then' block of the original affine.if in the then version.
314 promoteIfBlock(ifOp, /*elseBlock=*/false);
315
316 // Move the then version to the hoisted if op's 'then' block.
317 auto *thenBlock = hoistedIfOp.getThenBlock();
318 thenBlock->getOperations().splice(thenBlock->begin(),
319 hoistOverOp->getBlock()->getOperations(),
320 Block::iterator(hoistOverOp));
321
322 // Find the clone of the original affine.if op in the else version.
323 AffineIfOp ifCloneInElse;
324 hoistOverOpClone->walk([&](AffineIfOp ifClone) {
325 if (!ifClone->getAttr(idForIfOp))
326 return WalkResult::advance();
327 ifCloneInElse = ifClone;
328 return WalkResult::interrupt();
329 });
330 assert(ifCloneInElse && "if op clone should exist");
331 // For the else block, promote the else block of the original 'if' if it had
332 // one; otherwise, the op itself is to be erased.
333 if (!ifCloneInElse.hasElse())
334 ifCloneInElse.erase();
335 else
336 promoteIfBlock(ifCloneInElse, /*elseBlock=*/true);
337
338 // Move the else version into the else block of the hoisted if op.
339 auto *elseBlock = hoistedIfOp.getElseBlock();
340 elseBlock->getOperations().splice(
341 elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
342 Block::iterator(hoistOverOpClone));
343
344 return hoistedIfOp;
345 }
346
347 LogicalResult
348 mlir::affineParallelize(AffineForOp forOp,
349 ArrayRef<LoopReduction> parallelReductions) {
350 // Fail early if there are iter arguments that are not reductions.
351 unsigned numReductions = parallelReductions.size();
352 if (numReductions != forOp.getNumIterOperands())
353 return failure();
354
355 Location loc = forOp.getLoc();
356 OpBuilder outsideBuilder(forOp);
357 AffineMap lowerBoundMap = forOp.getLowerBoundMap();
358 ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
359 AffineMap upperBoundMap = forOp.getUpperBoundMap();
360 ValueRange upperBoundOperands = forOp.getUpperBoundOperands();
361
362   // Create an empty 1-D affine.parallel op.
363 auto reducedValues = llvm::to_vector<4>(llvm::map_range(
364 parallelReductions, [](const LoopReduction &red) { return red.value; }));
365 auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
366 parallelReductions, [](const LoopReduction &red) { return red.kind; }));
367 AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
368 loc, ValueRange(reducedValues).getTypes(), reductionKinds,
369 llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands,
370 llvm::makeArrayRef(upperBoundMap), upperBoundOperands,
371 llvm::makeArrayRef(forOp.getStep()));
372 // Steal the body of the old affine for op.
373 newPloop.getRegion().takeBody(forOp.getRegion());
374 Operation *yieldOp = &newPloop.getBody()->back();
375
376 // Handle the initial values of reductions because the parallel loop always
377 // starts from the neutral value.
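  // For instance (illustrative): for an add reduction with initial value
  // %init, the op that fed the terminator is moved after the affine.parallel
  // and rewritten to compute %init + <parallel result>, while the loop itself
  // reduces starting from the neutral element.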
378 SmallVector<Value> newResults;
379 newResults.reserve(numReductions);
380 for (unsigned i = 0; i < numReductions; ++i) {
381 Value init = forOp.getIterOperands()[i];
382 // This works because we are only handling single-op reductions at the
383 // moment. A switch on reduction kind or a mechanism to collect operations
384 // participating in the reduction will be necessary for multi-op reductions.
385 Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
386 assert(reductionOp && "yielded value is expected to be produced by an op");
387 outsideBuilder.getInsertionBlock()->getOperations().splice(
388 outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
389 reductionOp);
390 reductionOp->setOperands({init, newPloop->getResult(i)});
391 forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
392 }
393
394 // Update the loop terminator to yield reduced values bypassing the reduction
395 // operation itself (now moved outside of the loop) and erase the block
396 // arguments that correspond to reductions. Note that the loop always has one
397   // "main" induction variable when coming from a non-parallel for.
398 unsigned numIVs = 1;
399 yieldOp->setOperands(reducedValues);
400 newPloop.getBody()->eraseArguments(
401 llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));
402
403 forOp.erase();
404 return success();
405 }
406
407 // Returns success if any hoisting happened.
408 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
409 // Bail out early if the ifOp returns a result. TODO: Consider how to
410 // properly support this case.
411 if (ifOp.getNumResults() != 0)
412 return failure();
413
414 // Apply canonicalization patterns and folding - this is necessary for the
415 // hoisting check to be correct (operands should be composed), and to be more
416 // effective (no unused operands). Since the pattern rewriter's folding is
417 // entangled with application of patterns, we may fold/end up erasing the op,
418 // in which case we return with `folded` being set.
419 RewritePatternSet patterns(ifOp.getContext());
420 AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
421 bool erased;
422 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
423 (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
424 if (erased) {
425 if (folded)
426 *folded = true;
427 return failure();
428 }
429 if (folded)
430 *folded = false;
431
432 // The folding above should have ensured this, but the affine.if's
433 // canonicalization is missing composition of affine.applys into it.
434 assert(llvm::all_of(ifOp.getOperands(),
435 [](Value v) {
436 return isTopLevelValue(v) || isForInductionVar(v);
437 }) &&
438 "operands not composed");
439
440   // We are going to hoist as high as possible.
441 // TODO: this could be customized in the future.
442 auto *hoistOverOp = getOutermostInvariantForOp(ifOp);
443
444 AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
445 // Nothing to hoist over.
446 if (hoistedIfOp == ifOp)
447 return failure();
448
449 // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
450 // a sequence of affine.fors that are all perfectly nested).
451 (void)applyPatternsAndFoldGreedily(
452 hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
453 frozenPatterns);
454
455 return success();
456 }
457
458 // Return the min expr after replacing the given dim.
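// For instance (an illustrative sketch, not from the original source):
// substituting d0 in `s0 - d0` (i.e. d0 * -1 + s0) on a positive path yields
// `s0 - max`, because the negative coefficient flips the path and `max`
// replaces d0.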
459 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
460 AffineExpr max, bool positivePath) {
461 if (e == dim)
462 return positivePath ? min : max;
463 if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
464 AffineExpr lhs = bin.getLHS();
465 AffineExpr rhs = bin.getRHS();
466 if (bin.getKind() == mlir::AffineExprKind::Add)
467 return substWithMin(lhs, dim, min, max, positivePath) +
468 substWithMin(rhs, dim, min, max, positivePath);
469
470 auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
471 auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
472 if (c1 && c1.getValue() < 0)
473 return getAffineBinaryOpExpr(
474 bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
475 if (c2 && c2.getValue() < 0)
476 return getAffineBinaryOpExpr(
477 bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
478 return getAffineBinaryOpExpr(
479 bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
480 substWithMin(rhs, dim, min, max, positivePath));
481 }
482 return e;
483 }
484
485 void mlir::normalizeAffineParallel(AffineParallelOp op) {
486 // Loops with min/max in bounds are not normalized at the moment.
487 if (op.hasMinMaxBounds())
488 return;
489
490 AffineMap lbMap = op.getLowerBoundsMap();
491 SmallVector<int64_t, 8> steps = op.getSteps();
492 // No need to do any work if the parallel op is already normalized.
493 bool isAlreadyNormalized =
494 llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
495 int64_t step = std::get<0>(tuple);
496 auto lbExpr =
497 std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
498 return lbExpr && lbExpr.getValue() == 0 && step == 1;
499 });
500 if (isAlreadyNormalized)
501 return;
502
503 AffineValueMap ranges;
504 AffineValueMap::difference(op.getUpperBoundsValueMap(),
505 op.getLowerBoundsValueMap(), &ranges);
506 auto builder = OpBuilder::atBlockBegin(op.getBody());
507 auto zeroExpr = builder.getAffineConstantExpr(0);
508 SmallVector<AffineExpr, 8> lbExprs;
509 SmallVector<AffineExpr, 8> ubExprs;
510 for (unsigned i = 0, e = steps.size(); i < e; ++i) {
511 int64_t step = steps[i];
512
513 // Adjust the lower bound to be 0.
514 lbExprs.push_back(zeroExpr);
515
516     // Adjust the upper bound expression: 'range ceildiv step'.
517 AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
518 ubExprs.push_back(ubExpr);
519
520 // Adjust the corresponding IV: 'lb + i * step'.
521 BlockArgument iv = op.getBody()->getArgument(i);
522 AffineExpr lbExpr = lbMap.getResult(i);
523 unsigned nDims = lbMap.getNumDims();
524 auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
525 auto map = AffineMap::get(/*dimCount=*/nDims + 1,
526 /*symbolCount=*/lbMap.getNumSymbols(), expr);
527
528 // Use an 'affine.apply' op that will be simplified later in subsequent
529 // canonicalizations.
530 OperandRange lbOperands = op.getLowerBoundsOperands();
531 OperandRange dimOperands = lbOperands.take_front(nDims);
532 OperandRange symbolOperands = lbOperands.drop_front(nDims);
533 SmallVector<Value, 8> applyOperands{dimOperands};
534 applyOperands.push_back(iv);
535 applyOperands.append(symbolOperands.begin(), symbolOperands.end());
536 auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
537 iv.replaceAllUsesExcept(apply, apply);
538 }
539
540 SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
541 op.setSteps(newSteps);
542 auto newLowerMap = AffineMap::get(
543 /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
544 op.setLowerBounds({}, newLowerMap);
545 auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
546 ubExprs, op.getContext());
547 op.setUpperBounds(ranges.getOperands(), newUpperMap);
548 }
549
550 /// Normalizes affine.for ops. If the affine.for op has only a single
551 /// iteration, it is simply promoted; otherwise, it is normalized in the
552 /// traditional way by converting the lower bound to zero and the loop step to
553 /// one. The upper bound is set to the trip count of the loop. For now, the
554 /// original loop's lower bound map must have a single result only; there is no
555 /// such restriction on upper bounds.
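///
/// For instance (an illustrative sketch, not from the original source),
/// `affine.for %i = 2 to 34 step 4` is rewritten to
/// `affine.for %i = 0 to 8 step 1`, with uses of the original IV replaced by
/// `affine.apply affine_map<(d0) -> (d0 * 4 + 2)>(%i)`.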
556 LogicalResult mlir::normalizeAffineFor(AffineForOp op) {
557 if (succeeded(promoteIfSingleIteration(op)))
558 return success();
559
560   // Check if the for op is already normalized.
561 if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
562 (op.getStep() == 1))
563 return success();
564
565 // Check if the lower bound has a single result only. Loops with a max lower
566 // bound can't be normalized without additional support like
567 // affine.execute_region's. If the lower bound does not have a single result
568 // then skip this op.
569 if (op.getLowerBoundMap().getNumResults() != 1)
570 return failure();
571
572 Location loc = op.getLoc();
573 OpBuilder opBuilder(op);
574 int64_t origLoopStep = op.getStep();
575
576 // Calculate upperBound for normalized loop.
577 SmallVector<Value, 4> ubOperands;
578 AffineBound lb = op.getLowerBound();
579 AffineBound ub = op.getUpperBound();
580 ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands());
581 AffineMap origLbMap = lb.getMap();
582 AffineMap origUbMap = ub.getMap();
583
584 // Add dimension operands from upper/lower bound.
585 for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
586 ubOperands.push_back(ub.getOperand(j));
587 for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
588 ubOperands.push_back(lb.getOperand(j));
589
590 // Add symbol operands from upper/lower bound.
591 for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
592 ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
593 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
594 ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
595
596 // Add original result expressions from lower/upper bound map.
597 SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(),
598 origLbMap.getResults().end());
599 SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(),
600 origUbMap.getResults().end());
601 SmallVector<AffineExpr, 4> newUbExprs;
602
603 // The original upperBound can have more than one result. For the new
604 // upperBound of this loop, take difference of all possible combinations of
605 // the ub results and lb result and ceildiv with the loop step. For e.g.,
606 //
607 // affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0)
608 // will have an upperBound map as,
609 // affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1, (1024 - 0) ceildiv
610 // 1)>(%i0)
611 //
612 // Insert all combinations of upper/lower bound results.
613 for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) {
614 newUbExprs.push_back(
615 (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep));
616 }
617
618 // Construct newUbMap.
619 AffineMap newUbMap =
620 AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(),
621 origLbMap.getNumSymbols() + origUbMap.getNumSymbols(),
622 newUbExprs, opBuilder.getContext());
623 canonicalizeMapAndOperands(&newUbMap, &ubOperands);
624
625 SmallVector<Value, 4> lbOperands(lb.getOperands().begin(),
626 lb.getOperands().begin() +
627 lb.getMap().getNumDims());
628
629 // Normalize the loop.
630 op.setUpperBound(ubOperands, newUbMap);
631 op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
632 op.setStep(1);
633
634 // Calculate the Value of new loopIV. Create affine.apply for the value of
635 // the loopIV in normalized loop.
636 opBuilder.setInsertionPointToStart(op.getBody());
637 // Add an extra dim operand for loopIV.
638 lbOperands.push_back(op.getInductionVar());
639 // Add symbol operands from lower bound.
640 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
641 lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
642
643 AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims());
644 AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0);
645 AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
646 origLbMap.getNumSymbols(), newIVExpr);
647 canonicalizeMapAndOperands(&ivMap, &lbOperands);
648 Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
649 op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
650 return success();
651 }
652
653 /// Ensure that all operations that could be executed after `start`
654 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
655 /// between the operations) do not have the potential memory effect
656 /// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to
657 /// a memref. For example, if `EffectType` is MemoryEffects::Write, this method
658 /// will check if there is no write to the memory between `start` and `memOp`
659 /// that would change the read within `memOp`.
660 template <typename EffectType, typename T>
661 static bool hasNoInterveningEffect(Operation *start, T memOp) {
662 Value memref = memOp.getMemRef();
663 bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() ||
664 memref.getDefiningOp<memref::AllocOp>();
665
666 // A boolean representing whether an intervening operation could have impacted
667 // memOp.
668 bool hasSideEffect = false;
669
670 // Check whether the effect on memOp can be caused by a given operation op.
671 std::function<void(Operation *)> checkOperation = [&](Operation *op) {
672     // If the effect has already been found, exit early.
673 if (hasSideEffect)
674 return;
675
676 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
677 SmallVector<MemoryEffects::EffectInstance, 1> effects;
678 memEffect.getEffects(effects);
679
680 bool opMayHaveEffect = false;
681 for (auto effect : effects) {
682 // If op causes EffectType on a potentially aliasing location for
683 // memOp, mark as having the effect.
684 if (isa<EffectType>(effect.getEffect())) {
685 if (isOriginalAllocation && effect.getValue() &&
686 (effect.getValue().getDefiningOp<memref::AllocaOp>() ||
687 effect.getValue().getDefiningOp<memref::AllocOp>())) {
688 if (effect.getValue() != memref)
689 continue;
690 }
691 opMayHaveEffect = true;
692 break;
693 }
694 }
695
696 if (!opMayHaveEffect)
697 return;
698
699 // If the side effect comes from an affine read or write, try to
700 // prove the side effecting `op` cannot reach `memOp`.
701 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
702 MemRefAccess srcAccess(op);
703 MemRefAccess destAccess(memOp);
704 // Dependence analysis is only correct if both ops operate on the same
705 // memref.
706 if (srcAccess.memref == destAccess.memref) {
707 FlatAffineValueConstraints dependenceConstraints;
708
709 // Number of loops containing the start op and the ending operation.
710 unsigned minSurroundingLoops =
711 getNumCommonSurroundingLoops(*start, *memOp);
712
713 // Number of loops containing the operation `op` which has the
714 // potential memory side effect and can occur on a path between
715 // `start` and `memOp`.
716 unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp);
717
718         // For ease, let's consider the case that `op` is a store and we're
719         // looking for other potential stores (e.g., `op`) that overwrite
720         // memory after `start` and before it is read in `memOp`. In this
721         // case, we only need to consider other potential stores with depth >
722         // minSurroundingLoops, since `start` would already overwrite any
723         // store with a smaller number of surrounding loops.
724 unsigned d;
725 for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
726 DependenceResult result = checkMemrefAccessDependence(
727 srcAccess, destAccess, d, &dependenceConstraints,
728 /*dependenceComponents=*/nullptr);
729 if (hasDependence(result)) {
730 hasSideEffect = true;
731 return;
732 }
733 }
734
735 // No side effect was seen, simply return.
736 return;
737 }
738 }
739 hasSideEffect = true;
740 return;
741 }
742
743 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
744 // Recurse into the regions for this op and check whether the internal
745 // operations may have the side effect `EffectType` on memOp.
746       for (Region &region : op->getRegions())
747 for (Block &block : region)
748 for (Operation &op : block)
749 checkOperation(&op);
750 return;
751 }
752
753     // Otherwise, conservatively assume a generic operation may have the
754     // effect on `memOp`.
755 hasSideEffect = true;
756 };
757
758 // Check all paths from ancestor op `parent` to the operation `to` for the
759 // effect. It is known that `to` must be contained within `parent`.
760 auto until = [&](Operation *parent, Operation *to) {
761     // TODO: check only the paths from `parent` to `to`.
762     // Currently we fall back and check the entire parent op rather than just
763     // the paths from `parent` to `to`, stopping after reaching `to`. This is
764     // conservatively correct, but could be made more aggressive.
765 assert(parent->isAncestor(to));
766 checkOperation(parent);
767 };
768
769 // Check for all paths from operation `from` to operation `untilOp` for the
770 // given memory effect.
771 std::function<void(Operation *, Operation *)> recur =
772 [&](Operation *from, Operation *untilOp) {
773 assert(
774 from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
775 "Checking for side effect between two operations without a common "
776 "ancestor");
777
778         // If the operations are in different regions, recursively consider
779         // all paths from `from` to the parent of `to` and all paths from the
780         // parent of `to` to `to`.
781 if (from->getParentRegion() != untilOp->getParentRegion()) {
782 recur(from, untilOp->getParentOp());
783 until(untilOp->getParentOp(), untilOp);
784 return;
785 }
786
787 // Now, assuming that `from` and `to` exist in the same region, perform
788 // a CFG traversal to check all the relevant operations.
789
790 // Additional blocks to consider.
791 SmallVector<Block *, 2> todoBlocks;
792 {
793           // First consider the parent block of `from` and check all operations
794 // after `from`.
795 for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
796 iter != end && &*iter != untilOp; ++iter) {
797 checkOperation(&*iter);
798 }
799
800 // If the parent of `from` doesn't contain `to`, add the successors
801 // to the list of blocks to check.
802 if (untilOp->getBlock() != from->getBlock())
803 for (Block *succ : from->getBlock()->getSuccessors())
804 todoBlocks.push_back(succ);
805 }
806
807 SmallPtrSet<Block *, 4> done;
808 // Traverse the CFG until hitting `to`.
809 while (!todoBlocks.empty()) {
810 Block *blk = todoBlocks.pop_back_val();
811 if (done.count(blk))
812 continue;
813 done.insert(blk);
814 for (auto &op : *blk) {
815 if (&op == untilOp)
816 break;
817 checkOperation(&op);
818 if (&op == blk->getTerminator())
819 for (Block *succ : blk->getSuccessors())
820 todoBlocks.push_back(succ);
821 }
822 }
823 };
824 recur(start, memOp);
825 return !hasSideEffect;
826 }
827
828 /// Attempt to eliminate loadOp by replacing it with a value stored into memory
829 /// which the load is guaranteed to retrieve. This check involves three
830 /// components: 1) the store and the load must access the same location; 2) the
831 /// store must dominate (and therefore always execute before) the load; and
832 /// 3) no other operation may overwrite the memory read by the load between the
833 /// store and the load. If such a value exists, the replaced `loadOp` will be
834 /// added to `loadOpsToErase` and its memref will be added to `memrefsToErase`.
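///
/// For instance (an illustrative sketch, not from the original source):
///   affine.store %v, %A[%i]
///   ...                        // no intervening write to %A[%i]
///   %x = affine.load %A[%i]    // all uses of %x are replaced by %v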
835 static LogicalResult forwardStoreToLoad(
836 AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
837 SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
838
839 // The store op candidate for forwarding that satisfies all conditions
840 // to replace the load, if any.
841 Operation *lastWriteStoreOp = nullptr;
842
843 for (auto *user : loadOp.getMemRef().getUsers()) {
844 auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
845 if (!storeOp)
846 continue;
847 MemRefAccess srcAccess(storeOp);
848 MemRefAccess destAccess(loadOp);
849
850 // 1. Check if the store and the load have mathematically equivalent
851 // affine access functions; this implies that they statically refer to the
852 // same single memref element. As an example this filters out cases like:
853 // store %A[%i0 + 1]
854 // load %A[%i0]
855 // store %A[%M]
856 // load %A[%N]
857 // Use the AffineValueMap difference based memref access equality checking.
858 if (srcAccess != destAccess)
859 continue;
860
861     // 2. The store has to dominate the load op to be a candidate.
862 if (!domInfo.dominates(storeOp, loadOp))
863 continue;
864
865 // 3. Ensure there is no intermediate operation which could replace the
866 // value in memory.
867 if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
868 continue;
869
870 // We now have a candidate for forwarding.
871 assert(lastWriteStoreOp == nullptr &&
872            "multiple simultaneous replacement stores");
873 lastWriteStoreOp = storeOp;
874 }
875
876 if (!lastWriteStoreOp)
877 return failure();
878
879 // Perform the actual store to load forwarding.
880 Value storeVal =
881 cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
882   // Check if the two values have the same type. This is needed for affine
883   // vector loads and stores where the vector shapes must match.
884 if (storeVal.getType() != loadOp.getValue().getType())
885 return failure();
886 loadOp.getValue().replaceAllUsesWith(storeVal);
887 // Record the memref for a later sweep to optimize away.
888 memrefsToErase.insert(loadOp.getMemRef());
889 // Record this to erase later.
890 loadOpsToErase.push_back(loadOp);
891 return success();
892 }
893
894 // This attempts to find stores which have no impact on the final result.
895 // A writing op writeA will be eliminated if there exists an op writeB such that
896 // 1) writeA and writeB have mathematically equivalent affine access functions,
897 // 2) writeB postdominates writeA, and
898 // 3) there is no potential read between writeA and writeB.
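// For instance (an illustrative sketch, not from the original source):
//   affine.store %v0, %A[%i]   // writeA: erased
//   affine.store %v1, %A[%i]   // writeB: postdominates writeA, no read between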
899 static void findUnusedStore(AffineWriteOpInterface writeA,
900 SmallVectorImpl<Operation *> &opsToErase,
901 PostDominanceInfo &postDominanceInfo) {
902
903 for (Operation *user : writeA.getMemRef().getUsers()) {
904 // Only consider writing operations.
905 auto writeB = dyn_cast<AffineWriteOpInterface>(user);
906 if (!writeB)
907 continue;
908
909 // The operations must be distinct.
910 if (writeB == writeA)
911 continue;
912
913 // Both operations must lie in the same region.
914 if (writeB->getParentRegion() != writeA->getParentRegion())
915 continue;
916
917 // Both operations must write to the same memory.
918 MemRefAccess srcAccess(writeB);
919 MemRefAccess destAccess(writeA);
920
921 if (srcAccess != destAccess)
922 continue;
923
924 // writeB must postdominate writeA.
925 if (!postDominanceInfo.postDominates(writeB, writeA))
926 continue;
927
928 // There cannot be an operation which reads from memory between
929 // the two writes.
930 if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
931 continue;
932
933 opsToErase.push_back(writeA);
934 break;
935 }
936 }
937
938 // The load to load forwarding / redundant load elimination is similar to the
939 // store to load forwarding.
940 // loadA will be replaced with loadB if:
941 // 1) loadA and loadB have mathematically equivalent affine access functions.
942 // 2) loadB dominates loadA.
943 // 3) There is no write between loadA and loadB.
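// For instance (an illustrative sketch, not from the original source):
//   %x = affine.load %A[%i]    // loadB
//   %y = affine.load %A[%i]    // loadA: replaced by %x (no intervening write)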
944 static void loadCSE(AffineReadOpInterface loadA,
945 SmallVectorImpl<Operation *> &loadOpsToErase,
946 DominanceInfo &domInfo) {
947 SmallVector<AffineReadOpInterface, 4> loadCandidates;
948 for (auto *user : loadA.getMemRef().getUsers()) {
949 auto loadB = dyn_cast<AffineReadOpInterface>(user);
950 if (!loadB || loadB == loadA)
951 continue;
952
953 MemRefAccess srcAccess(loadB);
954 MemRefAccess destAccess(loadA);
955
956 // 1. The accesses have to be to the same location.
957 if (srcAccess != destAccess) {
958 continue;
959 }
960
961     // 2. loadB has to dominate loadA to be a candidate.
962 if (!domInfo.dominates(loadB, loadA))
963 continue;
964
965 // 3. There is no write between loadA and loadB.
966 if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
967 loadA))
968 continue;
969
970     // Check if the two values have the same type. This is needed for affine
971     // vector loads where the vector shapes must match.
972 if (loadB.getValue().getType() != loadA.getValue().getType())
973 continue;
974
975 loadCandidates.push_back(loadB);
976 }
977
978   // Of the legal load candidates, use the one that dominates all others to
979   // minimize the subsequent need for load CSE.
980 Value loadB;
981 for (AffineReadOpInterface option : loadCandidates) {
982 if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
983 return depStore == option ||
984 domInfo.dominates(option.getOperation(),
985 depStore.getOperation());
986 })) {
987 loadB = option.getValue();
988 break;
989 }
990 }
991
992 if (loadB) {
993 loadA.getValue().replaceAllUsesWith(loadB);
994 // Record this to erase later.
995 loadOpsToErase.push_back(loadA);
996 }
997 }
998
999 // The store to load forwarding and load CSE rely on three conditions:
1000 //
1001 // 1) store/load providing a replacement value and load being replaced need to
1002 // have mathematically equivalent affine access functions (checked after full
1003 // composition of load/store operands); this implies that they access the same
1004 // single memref element for all iterations of the common surrounding loop,
1005 //
1006 // 2) the store/load op should dominate the load op,
1007 //
1008 // 3) no operation that may write to memory read by the load being replaced can
1009 // occur after executing the instruction (load or store) providing the
1010 // replacement value and before the load being replaced (thus potentially
1011 // allowing overwriting the memory read by the load).
1012 //
1013 // The above conditions are simple to check, sufficient, and powerful for most
1014 // cases in practice - they are sufficient, but not necessary --- since they
1015 // don't reason about loops that are guaranteed to execute at least once or
1016 // multiple sources to forward from.
1017 //
1018 // TODO: more forwarding can be done when support for
1019 // loop/conditional live-out SSA values is available.
1020 // TODO: do general dead store elimination for memref's. This pass
1021 // currently eliminates the stores only if no other loads/uses (other
1022 // than dealloc) remain.
1023 //
1024 void mlir::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
1025 PostDominanceInfo &postDomInfo) {
1026 // Load op's whose results were replaced by those forwarded from stores.
1027 SmallVector<Operation *, 8> opsToErase;
1028
1029 // A list of memref's that are potentially dead / could be eliminated.
1030 SmallPtrSet<Value, 4> memrefsToErase;
1031
1032 // Walk all load's and perform store to load forwarding.
1033 f.walk([&](AffineReadOpInterface loadOp) {
1034 if (failed(
1035 forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
1036 loadCSE(loadOp, opsToErase, domInfo);
1037 }
1038 });
1039
1040 // Erase all load op's whose results were replaced with store fwd'ed ones.
1041 for (auto *op : opsToErase)
1042 op->erase();
1043 opsToErase.clear();
1044
1045 // Walk all store's and perform unused store elimination
1046 f.walk([&](AffineWriteOpInterface storeOp) {
1047 findUnusedStore(storeOp, opsToErase, postDomInfo);
1048 });
1049 // Erase all store op's which don't impact the program
1050 for (auto *op : opsToErase)
1051 op->erase();
1052
1053 // Check if the store fwd'ed memrefs are now left with only stores and can
1054 // thus be completely deleted. Note: the canonicalize pass should be able
1055 // to do this as well, but we'll do it here since we collected these anyway.
1056 for (auto memref : memrefsToErase) {
1057 // If the memref hasn't been alloc'ed in this function, skip.
1058 Operation *defOp = memref.getDefiningOp();
1059 if (!defOp || !isa<memref::AllocOp>(defOp))
1060 // TODO: if the memref was returned by a 'call' operation, we
1061 // could still erase it if the call had no side-effects.
1062 continue;
1063 if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
1064 return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
1065 }))
1066 continue;
1067
1068 // Erase all stores, the dealloc, and the alloc on the memref.
1069 for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
1070 user->erase();
1071 defOp->erase();
1072 }
1073 }
1074
1075 // Perform the replacement in `op`.
1076 LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
1077 Operation *op,
1078 ArrayRef<Value> extraIndices,
1079 AffineMap indexRemap,
1080 ArrayRef<Value> extraOperands,
1081 ArrayRef<Value> symbolOperands,
1082 bool allowNonDereferencingOps) {
1083 unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1084 (void)newMemRefRank; // unused in opt mode
1085 unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1086 (void)oldMemRefRank; // unused in opt mode
1087 if (indexRemap) {
1088 assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1089 "symbolic operand count mismatch");
1090 assert(indexRemap.getNumInputs() ==
1091 extraOperands.size() + oldMemRefRank + symbolOperands.size());
1092 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1093 } else {
1094 assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1095 }
1096
1097 // Assert same elemental type.
1098 assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1099 newMemRef.getType().cast<MemRefType>().getElementType());
1100
1101 SmallVector<unsigned, 2> usePositions;
1102 for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
1103 if (opEntry.value() == oldMemRef)
1104 usePositions.push_back(opEntry.index());
1105 }
1106
1107 // If memref doesn't appear, nothing to do.
1108 if (usePositions.empty())
1109 return success();
1110
1111 if (usePositions.size() > 1) {
1112 // TODO: extend it for this case when needed (rare).
1113 assert(false && "multiple dereferencing uses in a single op not supported");
1114 return failure();
1115 }
1116
1117 unsigned memRefOperandPos = usePositions.front();
1118
1119 OpBuilder builder(op);
1120 // The following checks if op is dereferencing memref and performs the access
1121 // index rewrites.
1122 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
1123 if (!affMapAccInterface) {
1124 if (!allowNonDereferencingOps) {
1125 // Failure: memref used in a non-dereferencing context (potentially
1126 // escapes); no replacement in these cases unless allowNonDereferencingOps
1127 // is set.
1128 return failure();
1129 }
1130 op->setOperand(memRefOperandPos, newMemRef);
1131 return success();
1132 }
1133 // Perform index rewrites for the dereferencing op and then replace the op
1134 NamedAttribute oldMapAttrPair =
1135 affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
1136 AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
1137 unsigned oldMapNumInputs = oldMap.getNumInputs();
1138 SmallVector<Value, 4> oldMapOperands(
1139 op->operand_begin() + memRefOperandPos + 1,
1140 op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
1141
1142 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
1143 SmallVector<Value, 4> oldMemRefOperands;
1144 SmallVector<Value, 4> affineApplyOps;
1145 oldMemRefOperands.reserve(oldMemRefRank);
1146 if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
1147 for (auto resultExpr : oldMap.getResults()) {
1148 auto singleResMap = AffineMap::get(oldMap.getNumDims(),
1149 oldMap.getNumSymbols(), resultExpr);
1150 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1151 oldMapOperands);
1152 oldMemRefOperands.push_back(afOp);
1153 affineApplyOps.push_back(afOp);
1154 }
1155 } else {
1156 oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
1157 }
1158
1159 // Construct new indices as a remap of the old ones if a remapping has been
1160 // provided. The indices of a memref come right after it, i.e.,
1161 // at position memRefOperandPos + 1.
1162 SmallVector<Value, 4> remapOperands;
1163 remapOperands.reserve(extraOperands.size() + oldMemRefRank +
1164 symbolOperands.size());
1165 remapOperands.append(extraOperands.begin(), extraOperands.end());
1166 remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
1167 remapOperands.append(symbolOperands.begin(), symbolOperands.end());
1168
1169 SmallVector<Value, 4> remapOutputs;
1170 remapOutputs.reserve(oldMemRefRank);
1171
1172 if (indexRemap &&
1173 indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
1174 // Remapped indices.
1175 for (auto resultExpr : indexRemap.getResults()) {
1176 auto singleResMap = AffineMap::get(
1177 indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
1178 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
1179 remapOperands);
1180 remapOutputs.push_back(afOp);
1181 affineApplyOps.push_back(afOp);
1182 }
1183 } else {
1184 // No remapping specified.
1185 remapOutputs.assign(remapOperands.begin(), remapOperands.end());
1186 }
1187
1188 SmallVector<Value, 4> newMapOperands;
1189 newMapOperands.reserve(newMemRefRank);
1190
1191 // Prepend 'extraIndices' in 'newMapOperands'.
1192 for (Value extraIndex : extraIndices) {
1193 assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
1194 "single result op's expected to generate these indices");
1195 assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
1196 "invalid memory op index");
1197 newMapOperands.push_back(extraIndex);
1198 }
1199
1200 // Append 'remapOutputs' to 'newMapOperands'.
1201 newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
1202
1203 // Create new fully composed AffineMap for new op to be created.
1204 assert(newMapOperands.size() == newMemRefRank);
1205 auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
1206 // TODO: Avoid creating/deleting temporary AffineApplyOps here.
1207 fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
1208 newMap = simplifyAffineMap(newMap);
1209 canonicalizeMapAndOperands(&newMap, &newMapOperands);
1210 // Remove any affine.apply's that became dead as a result of composition.
1211 for (Value value : affineApplyOps)
1212 if (value.use_empty())
1213 value.getDefiningOp()->erase();
1214
1215 OperationState state(op->getLoc(), op->getName());
1216 // Construct the new operation using this memref.
1217 state.operands.reserve(op->getNumOperands() + extraIndices.size());
1218 // Insert the non-memref operands.
1219 state.operands.append(op->operand_begin(),
1220 op->operand_begin() + memRefOperandPos);
1221 // Insert the new memref value.
1222 state.operands.push_back(newMemRef);
1223
1224 // Insert the new memref map operands.
1225 state.operands.append(newMapOperands.begin(), newMapOperands.end());
1226
1227 // Insert the remaining operands unmodified.
1228 state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
1229 oldMapNumInputs,
1230 op->operand_end());
1231
1232 // Result types don't change. Both memref's are of the same elemental type.
1233 state.types.reserve(op->getNumResults());
1234 for (auto result : op->getResults())
1235 state.types.push_back(result.getType());
1236
1237 // Add attribute for 'newMap', other Attributes do not change.
1238 auto newMapAttr = AffineMapAttr::get(newMap);
1239 for (auto namedAttr : op->getAttrs()) {
1240 if (namedAttr.getName() == oldMapAttrPair.getName())
1241 state.attributes.push_back({namedAttr.getName(), newMapAttr});
1242 else
1243 state.attributes.push_back(namedAttr);
1244 }
1245
1246 // Create the new operation.
1247 auto *repOp = builder.create(state);
1248 op->replaceAllUsesWith(repOp);
1249 op->erase();
1250
1251 return success();
1252 }
1253
1254 LogicalResult mlir::replaceAllMemRefUsesWith(
1255 Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
1256 AffineMap indexRemap, ArrayRef<Value> extraOperands,
1257 ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1258 Operation *postDomOpFilter, bool allowNonDereferencingOps,
1259 bool replaceInDeallocOp) {
1260 unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
1261 (void)newMemRefRank; // unused in opt mode
1262 unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
1263 (void)oldMemRefRank;
1264 if (indexRemap) {
1265 assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
1266 "symbol operand count mismatch");
1267 assert(indexRemap.getNumInputs() ==
1268 extraOperands.size() + oldMemRefRank + symbolOperands.size());
1269 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
1270 } else {
1271 assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
1272 }
1273
1274 // Assert same elemental type.
1275 assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
1276 newMemRef.getType().cast<MemRefType>().getElementType());
1277
1278 std::unique_ptr<DominanceInfo> domInfo;
1279 std::unique_ptr<PostDominanceInfo> postDomInfo;
1280 if (domOpFilter)
1281 domInfo = std::make_unique<DominanceInfo>(
1282 domOpFilter->getParentOfType<func::FuncOp>());
1283
1284 if (postDomOpFilter)
1285 postDomInfo = std::make_unique<PostDominanceInfo>(
1286 postDomOpFilter->getParentOfType<func::FuncOp>());
1287
1288 // Walk all uses of old memref; collect ops to perform replacement. We use a
1289 // DenseSet since an operation could potentially have multiple uses of a
1290 // memref (although rare), and the replacement later is going to erase ops.
1291 DenseSet<Operation *> opsToReplace;
1292 for (auto *op : oldMemRef.getUsers()) {
1293 // Skip this use if it's not dominated by domOpFilter.
1294 if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1295 continue;
1296
1297 // Skip this use if it's not post-dominated by postDomOpFilter.
1298 if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1299 continue;
1300
1301 // Skip dealloc's - no replacement is necessary, and a memref replacement
1302 // at other uses doesn't hurt these dealloc's.
1303 if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
1304 continue;
1305
1306 // Check if the memref was used in a non-dereferencing context. It is fine
1307 // for the memref to be used in a non-dereferencing way outside of the
1308 // region where this replacement is happening.
1309 if (!isa<AffineMapAccessInterface>(*op)) {
1310 if (!allowNonDereferencingOps) {
1311 LLVM_DEBUG(llvm::dbgs()
1312                  << "Memref replacement failed: non-dereferencing memref op: \n"
1313 << *op << '\n');
1314 return failure();
1315 }
1316 // Non-dereferencing ops with the MemRefsNormalizable trait are
1317 // supported for replacement.
1318 if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1319 LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
1320 "memrefs normalizable trait: \n"
1321 << *op << '\n');
1322 return failure();
1323 }
1324 }
1325
1326 // We'll first collect and then replace --- since replacement erases the op
1327 // that has the use, and that op could be postDomFilter or domFilter itself!
1328 opsToReplace.insert(op);
1329 }
1330
1331 for (auto *op : opsToReplace) {
1332 if (failed(replaceAllMemRefUsesWith(
1333 oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1334 symbolOperands, allowNonDereferencingOps)))
1335 llvm_unreachable("memref replacement guaranteed to succeed here");
1336 }
1337
1338 return success();
1339 }
1340
1341 /// Given an operation, inserts one or more single result affine
1342 /// apply operations, results of which are exclusively used by this
1343 /// operation. The operands of these newly created affine apply ops are
1344 /// guaranteed to be loop iterators or terminal symbols of a function.
1345 ///
1346 /// Before
1347 ///
1348 /// affine.for %i = 0 to #map(%N)
1349 /// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1350 /// "send"(%idx, %A, ...)
1351 /// "compute"(%idx)
1352 ///
1353 /// After
1354 ///
1355 /// affine.for %i = 0 to #map(%N)
1356 /// %idx = affine.apply (d0) -> (d0 mod 2) (%i)
1357 /// "send"(%idx, %A, ...)
1358 /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
1359 /// "compute"(%idx_)
1360 ///
1361 /// This allows applying different transformations on send and compute (e.g.,
1362 /// different shifts/delays).
1363 ///
1364 /// Leaves `sliceOps` empty either if none of opInst's operands were the result
1365 /// of an affine.apply (i.e., there was no affine computation slice to create),
1366 /// or if all the affine.apply ops supplying operands to opInst had no uses
1367 /// besides opInst itself; otherwise returns the list of affine.apply operations
1368 /// created in the output argument `sliceOps`.
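///
/// A minimal calling sketch (hypothetical `computeOp`):
///
///   SmallVector<AffineApplyOp, 4> sliceOps;
///   createAffineComputationSlice(computeOp, &sliceOps);
///   if (!sliceOps.empty()) {
///     // computeOp now consumes its own private affine.apply results, which
///     // can be shifted or delayed independently of other users.
///   }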
1369 void mlir::createAffineComputationSlice(
1370 Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
1371 // Collect all operands that are results of affine apply ops.
1372 SmallVector<Value, 4> subOperands;
1373 subOperands.reserve(opInst->getNumOperands());
1374 for (auto operand : opInst->getOperands())
1375 if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
1376 subOperands.push_back(operand);
1377
1378 // Gather sequence of AffineApplyOps reachable from 'subOperands'.
1379 SmallVector<Operation *, 4> affineApplyOps;
1380 getReachableAffineApplyOps(subOperands, affineApplyOps);
1381 // Skip transforming if there are no affine maps to compose.
1382 if (affineApplyOps.empty())
1383 return;
1384
1385   // Check if all uses of the affine.apply ops lie only in this op, in which
1386   // case there would be nothing to do.
1387 bool localized = true;
1388 for (auto *op : affineApplyOps) {
1389 for (auto result : op->getResults()) {
1390 for (auto *user : result.getUsers()) {
1391 if (user != opInst) {
1392 localized = false;
1393 break;
1394 }
1395 }
1396 }
1397 }
1398 if (localized)
1399 return;
1400
1401 OpBuilder builder(opInst);
1402 SmallVector<Value, 4> composedOpOperands(subOperands);
1403 auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
1404 fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);
1405
1406 // Create an affine.apply for each of the map results.
1407 sliceOps->reserve(composedMap.getNumResults());
1408 for (auto resultExpr : composedMap.getResults()) {
1409 auto singleResMap = AffineMap::get(composedMap.getNumDims(),
1410 composedMap.getNumSymbols(), resultExpr);
1411 sliceOps->push_back(builder.create<AffineApplyOp>(
1412 opInst->getLoc(), singleResMap, composedOpOperands));
1413 }
1414
1415 // Construct the new operands that include the results from the composed
1416 // affine apply op above instead of existing ones (subOperands). So, they
1417 // differ from opInst's operands only for those operands in 'subOperands', for
1418 // which they will be replaced by the corresponding one from 'sliceOps'.
1419 SmallVector<Value, 4> newOperands(opInst->getOperands());
1420 for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
1421 // Replace the subOperands from among the new operands.
1422 unsigned j, f;
1423 for (j = 0, f = subOperands.size(); j < f; j++) {
1424 if (newOperands[i] == subOperands[j])
1425 break;
1426 }
1427 if (j < subOperands.size()) {
1428 newOperands[i] = (*sliceOps)[j];
1429 }
1430 }
1431 for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
1432 opInst->setOperand(idx, newOperands[idx]);
1433 }
1434 }
1435
1436 /// Enum to set patterns of affine expr in tiled-layout map.
1437 /// TileFloorDiv: <dim expr> floordiv <tile size>
1438 /// TileMod: <dim expr> mod <tile size>
1439 /// TileNone: None of the above
1440 /// Example:
1441 /// #tiled_2d_128x256 = affine_map<(d0, d1)
1442 ///            -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
1443 /// "d0 floordiv 128" and "d1 floordiv 256" ==> TileFloorDiv
1444 /// "d0 mod 128" and "d1 mod 256" ==> TileMod
1445 enum TileExprPattern { TileFloorDiv, TileMod, TileNone };
1446
1447 /// Check if `map` is a tiled layout. In a tiled layout, specific k dimensions
1448 /// that are floordiv'ed by their respective tile sizes also appear in a mod by
1449 /// the same tile sizes, and no other expression involves those k dimensions.
1450 /// This function stores in `tileSizePos` a vector of tuples: the AffineExpr
1451 /// for the tile size and the positions of the corresponding `floordiv` and
1452 /// `mod`. If `map` is not a tiled layout, an empty vector is returned.
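///
/// For example, for
///   affine_map<(d0, d1) -> (d0 floordiv 128, d1 floordiv 256,
///                           d0 mod 128, d1 mod 256)>
/// `tileSizePos` would hold {128, /*floordiv pos=*/0, /*mod pos=*/2} and
/// {256, /*floordiv pos=*/1, /*mod pos=*/3} (an illustration of the expected
/// contents, not additional behavior).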
1453 static LogicalResult getTileSizePos(
1454 AffineMap map,
1455 SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
1456   // Create `floordivExprs`, a vector of tuples holding the LHS and RHS of each
1457   // `floordiv` and its position in the `map` output.
1458   // Example: #tiled_2d_128x256 = affine_map<(d0, d1)
1459   //   -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128, d1 mod 256)>
1460   // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}.
1461 SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
1462 unsigned pos = 0;
1463 for (AffineExpr expr : map.getResults()) {
1464 if (expr.getKind() == AffineExprKind::FloorDiv) {
1465 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
1466 if (binaryExpr.getRHS().isa<AffineConstantExpr>())
1467 floordivExprs.emplace_back(
1468 std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
1469 }
1470 pos++;
1471 }
1472 // Not tiled layout if `floordivExprs` is empty.
1473 if (floordivExprs.empty()) {
1474 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1475 return success();
1476 }
1477
1478 // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is
1479 // not tiled layout.
1480 for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
1481 AffineExpr floordivExprLHS = std::get<0>(fexpr);
1482 AffineExpr floordivExprRHS = std::get<1>(fexpr);
1483 unsigned floordivPos = std::get<2>(fexpr);
1484
1485     // Walk each affine expr of the `map` output except `fexpr`, and check if the
1486     // LHS and RHS of `fexpr` are used as the LHS and RHS of a `mod`. If the LHS of
1487     // `fexpr` appears in any other expr, `map` is not tiled. Non-tiled examples:
1488 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)>
1489 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)>
1490 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod
1491 // 256)>
1492 bool found = false;
1493 pos = 0;
1494 for (AffineExpr expr : map.getResults()) {
1495 bool notTiled = false;
1496 if (pos != floordivPos) {
1497 expr.walk([&](AffineExpr e) {
1498 if (e == floordivExprLHS) {
1499 if (expr.getKind() == AffineExprKind::Mod) {
1500 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
1501 // If LHS and RHS of `mod` are the same with those of floordiv.
1502 if (floordivExprLHS == binaryExpr.getLHS() &&
1503 floordivExprRHS == binaryExpr.getRHS()) {
1504               // Save the tile size (RHS of `mod`) and the positions of `floordiv`
1505               // and `mod` if a matching `mod` has not been found yet.
1506 if (!found) {
1507 tileSizePos.emplace_back(
1508 std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
1509 found = true;
1510 } else {
1511                 // Non-tiled layout: multiple `mod`s with the same LHS.
1512 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1513 // mod 256, d2 mod 256)>
1514 notTiled = true;
1515 }
1516 } else {
1517 // Non tiled layout: RHS of `mod` is different from `floordiv`.
1518 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1519 // mod 128)>
1520 notTiled = true;
1521 }
1522 } else {
1523 // Non tiled layout: LHS is the same, but not `mod`.
1524 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2
1525 // floordiv 256)>
1526 notTiled = true;
1527 }
1528 }
1529 });
1530 }
1531 if (notTiled) {
1532 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
1533 return success();
1534 }
1535 pos++;
1536 }
1537 }
1538 return success();
1539 }
1540
1541 /// Check if the `dim` dimension of a memrefType with `layoutMap` becomes
1542 /// dynamic after normalization. Dimensions whose map output involves a dynamic
1543 /// input dimension become dynamic dimensions. Returns true if `dim` is a
1544 /// dynamic dimension.
1545 ///
1546 /// Example:
1547 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1548 ///
1549 /// If d1 is a dynamic dimension, the 2nd and 3rd dimensions of the map output
1549 /// are dynamic.
1550 /// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
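///
/// For instance, with `inMemrefTypeDynDims` = {1} in the example above,
/// dimensions 1 and 2 of the map output (d1 floordiv 32 and d1 mod 32) are
/// reported dynamic, while dimension 0 (d0) is not.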
1551 static bool
1552 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1553 SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
1554 MLIRContext *context) {
1555 bool isDynamicDim = false;
1556 AffineExpr expr = layoutMap.getResults()[dim];
1557 // Check if affine expr of the dimension includes dynamic dimension of input
1558 // memrefType.
1559 expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
1560 if (e.isa<AffineDimExpr>()) {
1561 for (unsigned dm : inMemrefTypeDynDims) {
1562 if (e == getAffineDimExpr(dm, context)) {
1563 isDynamicDim = true;
1564 }
1565 }
1566 }
1567 });
1568 return isDynamicDim;
1569 }
1570
1571 /// Create affine expr to calculate dimension size for a tiled-layout map.
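/// E.g., for a 32-tiled dimension: "d1 floordiv 32" becomes "d1 ceildiv 32"
/// (the number of tiles) and "d1 mod 32" becomes "32" (the tile size); any
/// other expression is returned unchanged.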
1572 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
1573 TileExprPattern pat) {
1574 // Create map output for the patterns.
1575 // "floordiv <tile size>" ==> "ceildiv <tile size>"
1576 // "mod <tile size>" ==> "<tile size>"
1577 AffineExpr newMapOutput;
1578 AffineBinaryOpExpr binaryExpr = nullptr;
1579 switch (pat) {
1580 case TileExprPattern::TileMod:
1581 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1582 newMapOutput = binaryExpr.getRHS();
1583 break;
1584 case TileExprPattern::TileFloorDiv:
1585 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>();
1586 newMapOutput = getAffineBinaryOpExpr(
1587 AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS());
1588 break;
1589 default:
1590 newMapOutput = oldMapOutput;
1591 }
1592 return newMapOutput;
1593 }
1594
1595 /// Create new maps to calculate each dimension size of `newMemRefType`, and
1596 /// create `newDynamicSizes` from them by using AffineApplyOp.
1597 ///
1598 /// Steps for normalizing dynamic memrefs for a tiled layout map
1599 /// Example:
1600 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)>
1601 ///   %0 = dim %arg0, %c1 : memref<4x?xf32>
1602 /// %1 = alloc(%0) : memref<4x?xf32, #map0>
1603 ///
1604 /// (Before this function)
1605 /// 1. Check if `map` (#map0) is a tiled layout using `getTileSizePos()`. Only
1606 /// a single layout map is supported.
1607 ///
1608 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It
1609 /// is memref<4x?x?xf32> in the above example.
1610 ///
1611 /// (In this function)
1612 /// 3. Create new maps to calculate each dimension of the normalized memrefType
1613 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the
1614 /// dimension size can be calculated by replacing "floordiv <tile size>" with
1615 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>".
1616 /// - New map in the above example
1617 /// #map0 = affine_map<(d0, d1) -> (d0)>
1618 /// #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)>
1619 /// #map2 = affine_map<(d0, d1) -> (32)>
1620 ///
1621 /// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp
1622 /// is used in dynamicSizes of new AllocOp.
1623 /// %0 = dim %arg0, %c1 : memref<4x?xf32>
1624 /// %c4 = arith.constant 4 : index
1625 /// %1 = affine.apply #map1(%c4, %0)
1626 /// %2 = affine.apply #map2(%c4, %0)
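///
/// The new dynamic sizes then feed the normalized alloc; the expected result
/// (a sketch) is:
///   %3 = alloc(%1, %2) : memref<4x?x?xf32>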
1627 static void createNewDynamicSizes(MemRefType oldMemRefType,
1628 MemRefType newMemRefType, AffineMap map,
1629 memref::AllocOp *allocOp, OpBuilder b,
1630 SmallVectorImpl<Value> &newDynamicSizes) {
1631 // Create new input for AffineApplyOp.
1632 SmallVector<Value, 4> inAffineApply;
1633 ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
1634 unsigned dynIdx = 0;
1635 for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) {
1636 if (oldMemRefShape[d] < 0) {
1637 // Use dynamicSizes of allocOp for dynamic dimension.
1638 inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]);
1639 dynIdx++;
1640 } else {
1641 // Create ConstantOp for static dimension.
1642 Attribute constantAttr =
1643 b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
1644 inAffineApply.emplace_back(
1645 b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
1646 }
1647 }
1648
1649 // Create new map to calculate each dimension size of new memref for each
1650   // original map output. Only for dynamic dimensions of `newMemRefType`.
1651 unsigned newDimIdx = 0;
1652 ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
1653 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1654 (void)getTileSizePos(map, tileSizePos);
1655 for (AffineExpr expr : map.getResults()) {
1656 if (newMemRefShape[newDimIdx] < 0) {
1657 // Create new maps to calculate each dimension size of new memref.
1658 enum TileExprPattern pat = TileExprPattern::TileNone;
1659 for (auto pos : tileSizePos) {
1660 if (newDimIdx == std::get<1>(pos))
1661 pat = TileExprPattern::TileFloorDiv;
1662 else if (newDimIdx == std::get<2>(pos))
1663 pat = TileExprPattern::TileMod;
1664 }
1665 AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
1666 AffineMap newMap =
1667 AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput);
1668 Value affineApp =
1669 b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply);
1670 newDynamicSizes.emplace_back(affineApp);
1671 }
1672 newDimIdx++;
1673 }
1674 }
1675
1676 // TODO: Currently works for static memrefs with a single layout map.
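// For illustration (a hand-written sketch, not output of a specific test):
// normalizing
//   %0 = memref.alloc() : memref<64x128xf32,
//                                affine_map<(d0, d1) -> (d0 floordiv 32, d1,
//                                                        d0 mod 32)>>
// rewrites it to
//   %0 = memref.alloc() : memref<2x128x32xf32>
// and remaps all dereferencing uses of the old memref accordingly.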
1677 LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) {
1678 MemRefType memrefType = allocOp->getType();
1679 OpBuilder b(*allocOp);
1680
1681 // Fetch a new memref type after normalizing the old memref to have an
1682 // identity map layout.
1683 MemRefType newMemRefType =
1684 normalizeMemRefType(memrefType, b, allocOp->getSymbolOperands().size());
1685 if (newMemRefType == memrefType)
1686 // Either memrefType already had an identity map or the map couldn't be
1687 // transformed to an identity map.
1688 return failure();
1689
1690 Value oldMemRef = allocOp->getResult();
1691
1692 SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
1693 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1694 memref::AllocOp newAlloc;
1695   // Check if `layoutMap` is a tiled layout. Only a single layout map is
1696   // supported for normalizing dynamic memrefs.
1697 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1698 (void)getTileSizePos(layoutMap, tileSizePos);
1699 if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
1700 MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
1701 SmallVector<Value, 4> newDynamicSizes;
1702 createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
1703 newDynamicSizes);
1704 // Add the new dynamic sizes in new AllocOp.
1705 newAlloc =
1706 b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1707 newDynamicSizes, allocOp->getAlignmentAttr());
1708 } else {
1709 newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
1710 allocOp->getAlignmentAttr());
1711 }
1712 // Replace all uses of the old memref.
1713 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
1714 /*extraIndices=*/{},
1715 /*indexRemap=*/layoutMap,
1716 /*extraOperands=*/{},
1717 /*symbolOperands=*/symbolOperands,
1718 /*domOpFilter=*/nullptr,
1719 /*postDomOpFilter=*/nullptr,
1720 /*allowNonDereferencingOps=*/true))) {
1721 // If it failed (due to escapes for example), bail out.
1722 newAlloc.erase();
1723 return failure();
1724 }
1725   // Replace any remaining uses of the original alloc op and erase it. All such
1726   // uses must be dealloc's; replaceAllMemRefUsesWith would've failed otherwise.
1727 assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) {
1728 return isa<memref::DeallocOp>(op);
1729 }));
1730 oldMemRef.replaceAllUsesWith(newAlloc);
1731 allocOp->erase();
1732 return success();
1733 }
1734
1735 MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
1736 unsigned numSymbolicOperands) {
1737 unsigned rank = memrefType.getRank();
1738 if (rank == 0)
1739 return memrefType;
1740
1741 if (memrefType.getLayout().isIdentity()) {
1742     // Either no map is associated with this memref or this memref has
1743 // a trivial (identity) map.
1744 return memrefType;
1745 }
1746 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
1747
1748 // We don't do any checks for one-to-one'ness; we assume that it is
1749 // one-to-one.
1750
1751 // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
1752 // for now.
1753 // TODO: Normalize the other types of dynamic memrefs.
1754 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
1755 (void)getTileSizePos(layoutMap, tileSizePos);
1756 if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
1757 return memrefType;
1758
1759 // We have a single map that is not an identity map. Create a new memref
1760 // with the right shape and an identity layout map.
1761 ArrayRef<int64_t> shape = memrefType.getShape();
1762   // FlatAffineValueConstraints may later on use the symbolic operands.
1763 FlatAffineValueConstraints fac(rank, numSymbolicOperands);
1764 SmallVector<unsigned, 4> memrefTypeDynDims;
1765 for (unsigned d = 0; d < rank; ++d) {
1766 // Use constraint system only in static dimensions.
1767 if (shape[d] > 0) {
1768 fac.addBound(IntegerPolyhedron::LB, d, 0);
1769 fac.addBound(IntegerPolyhedron::UB, d, shape[d] - 1);
1770 } else {
1771 memrefTypeDynDims.emplace_back(d);
1772 }
1773 }
1774 // We compose this map with the original index (logical) space to derive
1775 // the upper bounds for the new index space.
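  // For example (an illustrative sketch), for memref<64x128xf32> with layout
  // (d0, d1) -> (d0 floordiv 32, d1, d0 mod 32), composition yields the bounds
  // 0 <= r0 <= 1, 0 <= r1 <= 127, 0 <= r2 <= 31 on the new index space, i.e.
  // the normalized shape 2x128x32.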
1776 unsigned newRank = layoutMap.getNumResults();
1777 if (failed(fac.composeMatchingMap(layoutMap)))
1778 return memrefType;
1779 // TODO: Handle semi-affine maps.
1780 // Project out the old data dimensions.
1781 fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());
1782 SmallVector<int64_t, 4> newShape(newRank);
1783 for (unsigned d = 0; d < newRank; ++d) {
1784 // Check if each dimension of normalized memrefType is dynamic.
1785 bool isDynDim = isNormalizedMemRefDynamicDim(
1786 d, layoutMap, memrefTypeDynDims, b.getContext());
1787 if (isDynDim) {
1788 newShape[d] = -1;
1789 } else {
1790 // The lower bound for the shape is always zero.
1791 auto ubConst = fac.getConstantBound(IntegerPolyhedron::UB, d);
1792 // For a static memref and an affine map with no symbols, this is
1793 // always bounded.
1794 assert(ubConst && "should always have an upper bound");
1795 if (ubConst.value() < 0)
1796 // This is due to an invalid map that maps to a negative space.
1797 return memrefType;
1798 // If dimension of new memrefType is dynamic, the value is -1.
1799 newShape[d] = ubConst.value() + 1;
1800 }
1801 }
1802
1803 // Create the new memref type after trivializing the old layout map.
1804 MemRefType newMemRefType =
1805 MemRefType::Builder(memrefType)
1806 .setShape(newShape)
1807 .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank)));
1808
1809 return newMemRefType;
1810 }
1811