1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous transformation utilities for the Affine 10 // dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/Utils.h" 15 16 #include "mlir/Dialect/Affine/Analysis/Utils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/Func/IR/FuncOps.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/IR/AffineExprVisitor.h" 23 #include "mlir/IR/BlockAndValueMapping.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/IR/IntegerSet.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 28 #define DEBUG_TYPE "affine-utils" 29 30 using namespace mlir; 31 using namespace presburger; 32 33 namespace { 34 /// Visit affine expressions recursively and build the sequence of operations 35 /// that correspond to it. Visitation functions return an Value of the 36 /// expression subtree they visited or `nullptr` on error. 37 class AffineApplyExpander 38 : public AffineExprVisitor<AffineApplyExpander, Value> { 39 public: 40 /// This internal class expects arguments to be non-null, checks must be 41 /// performed at the call site. 
42 AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, 43 ValueRange symbolValues, Location loc) 44 : builder(builder), dimValues(dimValues), symbolValues(symbolValues), 45 loc(loc) {} 46 47 template <typename OpTy> 48 Value buildBinaryExpr(AffineBinaryOpExpr expr) { 49 auto lhs = visit(expr.getLHS()); 50 auto rhs = visit(expr.getRHS()); 51 if (!lhs || !rhs) 52 return nullptr; 53 auto op = builder.create<OpTy>(loc, lhs, rhs); 54 return op.getResult(); 55 } 56 57 Value visitAddExpr(AffineBinaryOpExpr expr) { 58 return buildBinaryExpr<arith::AddIOp>(expr); 59 } 60 61 Value visitMulExpr(AffineBinaryOpExpr expr) { 62 return buildBinaryExpr<arith::MulIOp>(expr); 63 } 64 65 /// Euclidean modulo operation: negative RHS is not allowed. 66 /// Remainder of the euclidean integer division is always non-negative. 67 /// 68 /// Implemented as 69 /// 70 /// a mod b = 71 /// let remainder = srem a, b; 72 /// negative = a < 0 in 73 /// select negative, remainder + b, remainder. 74 Value visitModExpr(AffineBinaryOpExpr expr) { 75 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 76 if (!rhsConst) { 77 emitError( 78 loc, 79 "semi-affine expressions (modulo by non-const) are not supported"); 80 return nullptr; 81 } 82 if (rhsConst.getValue() <= 0) { 83 emitError(loc, "modulo by non-positive value is not supported"); 84 return nullptr; 85 } 86 87 auto lhs = visit(expr.getLHS()); 88 auto rhs = visit(expr.getRHS()); 89 assert(lhs && rhs && "unexpected affine expr lowering failure"); 90 91 Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs); 92 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 93 Value isRemainderNegative = builder.create<arith::CmpIOp>( 94 loc, arith::CmpIPredicate::slt, remainder, zeroCst); 95 Value correctedRemainder = 96 builder.create<arith::AddIOp>(loc, remainder, rhs); 97 Value result = builder.create<arith::SelectOp>( 98 loc, isRemainderNegative, correctedRemainder, remainder); 99 return result; 100 } 101 102 /// 
Floor division operation (rounds towards negative infinity). 103 /// 104 /// For positive divisors, it can be implemented without branching and with a 105 /// single division operation as 106 /// 107 /// a floordiv b = 108 /// let negative = a < 0 in 109 /// let absolute = negative ? -a - 1 : a in 110 /// let quotient = absolute / b in 111 /// negative ? -quotient - 1 : quotient 112 Value visitFloorDivExpr(AffineBinaryOpExpr expr) { 113 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 114 if (!rhsConst) { 115 emitError( 116 loc, 117 "semi-affine expressions (division by non-const) are not supported"); 118 return nullptr; 119 } 120 if (rhsConst.getValue() <= 0) { 121 emitError(loc, "division by non-positive value is not supported"); 122 return nullptr; 123 } 124 125 auto lhs = visit(expr.getLHS()); 126 auto rhs = visit(expr.getRHS()); 127 assert(lhs && rhs && "unexpected affine expr lowering failure"); 128 129 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 130 Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1); 131 Value negative = builder.create<arith::CmpIOp>( 132 loc, arith::CmpIPredicate::slt, lhs, zeroCst); 133 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs); 134 Value dividend = 135 builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs); 136 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 137 Value correctedQuotient = 138 builder.create<arith::SubIOp>(loc, noneCst, quotient); 139 Value result = builder.create<arith::SelectOp>(loc, negative, 140 correctedQuotient, quotient); 141 return result; 142 } 143 144 /// Ceiling division operation (rounds towards positive infinity). 145 /// 146 /// For positive divisors, it can be implemented without branching and with a 147 /// single division operation as 148 /// 149 /// a ceildiv b = 150 /// let negative = a <= 0 in 151 /// let absolute = negative ? 
-a : a - 1 in 152 /// let quotient = absolute / b in 153 /// negative ? -quotient : quotient + 1 154 Value visitCeilDivExpr(AffineBinaryOpExpr expr) { 155 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 156 if (!rhsConst) { 157 emitError(loc) << "semi-affine expressions (division by non-const) are " 158 "not supported"; 159 return nullptr; 160 } 161 if (rhsConst.getValue() <= 0) { 162 emitError(loc, "division by non-positive value is not supported"); 163 return nullptr; 164 } 165 auto lhs = visit(expr.getLHS()); 166 auto rhs = visit(expr.getRHS()); 167 assert(lhs && rhs && "unexpected affine expr lowering failure"); 168 169 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 170 Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1); 171 Value nonPositive = builder.create<arith::CmpIOp>( 172 loc, arith::CmpIPredicate::sle, lhs, zeroCst); 173 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs); 174 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst); 175 Value dividend = 176 builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented); 177 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 178 Value negatedQuotient = 179 builder.create<arith::SubIOp>(loc, zeroCst, quotient); 180 Value incrementedQuotient = 181 builder.create<arith::AddIOp>(loc, quotient, oneCst); 182 Value result = builder.create<arith::SelectOp>( 183 loc, nonPositive, negatedQuotient, incrementedQuotient); 184 return result; 185 } 186 187 Value visitConstantExpr(AffineConstantExpr expr) { 188 auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue()); 189 return op.getResult(); 190 } 191 192 Value visitDimExpr(AffineDimExpr expr) { 193 assert(expr.getPosition() < dimValues.size() && 194 "affine dim position out of range"); 195 return dimValues[expr.getPosition()]; 196 } 197 198 Value visitSymbolExpr(AffineSymbolExpr expr) { 199 assert(expr.getPosition() < symbolValues.size() && 200 "symbol dim 
position out of range"); 201 return symbolValues[expr.getPosition()]; 202 } 203 204 private: 205 OpBuilder &builder; 206 ValueRange dimValues; 207 ValueRange symbolValues; 208 209 Location loc; 210 }; 211 } // namespace 212 213 /// Create a sequence of operations that implement the `expr` applied to the 214 /// given dimension and symbol values. 215 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, 216 AffineExpr expr, ValueRange dimValues, 217 ValueRange symbolValues) { 218 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); 219 } 220 221 /// Create a sequence of operations that implement the `affineMap` applied to 222 /// the given `operands` (as it it were an AffineApplyOp). 223 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder, 224 Location loc, 225 AffineMap affineMap, 226 ValueRange operands) { 227 auto numDims = affineMap.getNumDims(); 228 auto expanded = llvm::to_vector<8>( 229 llvm::map_range(affineMap.getResults(), 230 [numDims, &builder, loc, operands](AffineExpr expr) { 231 return expandAffineExpr(builder, loc, expr, 232 operands.take_front(numDims), 233 operands.drop_front(numDims)); 234 })); 235 if (llvm::all_of(expanded, [](Value v) { return v; })) 236 return expanded; 237 return None; 238 } 239 240 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether 241 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards 242 /// the rest of the op. 243 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) { 244 if (elseBlock) 245 assert(ifOp.hasElse() && "else block expected"); 246 247 Block *destBlock = ifOp->getBlock(); 248 Block *srcBlock = elseBlock ? 
ifOp.getElseBlock() : ifOp.getThenBlock(); 249 destBlock->getOperations().splice( 250 Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(), 251 std::prev(srcBlock->end())); 252 ifOp.erase(); 253 } 254 255 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant 256 /// on. The `ifOp` could be hoisted and placed right before such an operation. 257 /// This method assumes that the ifOp has been canonicalized (to be correct and 258 /// effective). 259 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) { 260 // Walk up the parents past all for op that this conditional is invariant on. 261 auto ifOperands = ifOp.getOperands(); 262 auto *res = ifOp.getOperation(); 263 while (!isa<func::FuncOp>(res->getParentOp())) { 264 auto *parentOp = res->getParentOp(); 265 if (auto forOp = dyn_cast<AffineForOp>(parentOp)) { 266 if (llvm::is_contained(ifOperands, forOp.getInductionVar())) 267 break; 268 } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) { 269 for (auto iv : parallelOp.getIVs()) 270 if (llvm::is_contained(ifOperands, iv)) 271 break; 272 } else if (!isa<AffineIfOp>(parentOp)) { 273 // Won't walk up past anything other than affine.for/if ops. 274 break; 275 } 276 // You can always hoist up past any affine.if ops. 277 res = parentOp; 278 } 279 return res; 280 } 281 282 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over 283 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened, 284 /// otherwise the same `ifOp`. 285 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { 286 // No hoisting to do. 287 if (hoistOverOp == ifOp) 288 return ifOp; 289 290 // Create the hoisted 'if' first. Then, clone the op we are hoisting over for 291 // the else block. 
Then drop the else block of the original 'if' in the 'then' 292 // branch while promoting its then block, and analogously drop the 'then' 293 // block of the original 'if' from the 'else' branch while promoting its else 294 // block. 295 BlockAndValueMapping operandMap; 296 OpBuilder b(hoistOverOp); 297 auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(), 298 ifOp.getOperands(), 299 /*elseBlock=*/true); 300 301 // Create a clone of hoistOverOp to use for the else branch of the hoisted 302 // conditional. The else block may get optimized away if empty. 303 Operation *hoistOverOpClone = nullptr; 304 // We use this unique name to identify/find `ifOp`'s clone in the else 305 // version. 306 StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting"); 307 operandMap.clear(); 308 b.setInsertionPointAfter(hoistOverOp); 309 // We'll set an attribute to identify this op in a clone of this sub-tree. 310 ifOp->setAttr(idForIfOp, b.getBoolAttr(true)); 311 hoistOverOpClone = b.clone(*hoistOverOp, operandMap); 312 313 // Promote the 'then' block of the original affine.if in the then version. 314 promoteIfBlock(ifOp, /*elseBlock=*/false); 315 316 // Move the then version to the hoisted if op's 'then' block. 317 auto *thenBlock = hoistedIfOp.getThenBlock(); 318 thenBlock->getOperations().splice(thenBlock->begin(), 319 hoistOverOp->getBlock()->getOperations(), 320 Block::iterator(hoistOverOp)); 321 322 // Find the clone of the original affine.if op in the else version. 323 AffineIfOp ifCloneInElse; 324 hoistOverOpClone->walk([&](AffineIfOp ifClone) { 325 if (!ifClone->getAttr(idForIfOp)) 326 return WalkResult::advance(); 327 ifCloneInElse = ifClone; 328 return WalkResult::interrupt(); 329 }); 330 assert(ifCloneInElse && "if op clone should exist"); 331 // For the else block, promote the else block of the original 'if' if it had 332 // one; otherwise, the op itself is to be erased. 
333 if (!ifCloneInElse.hasElse()) 334 ifCloneInElse.erase(); 335 else 336 promoteIfBlock(ifCloneInElse, /*elseBlock=*/true); 337 338 // Move the else version into the else block of the hoisted if op. 339 auto *elseBlock = hoistedIfOp.getElseBlock(); 340 elseBlock->getOperations().splice( 341 elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(), 342 Block::iterator(hoistOverOpClone)); 343 344 return hoistedIfOp; 345 } 346 347 LogicalResult 348 mlir::affineParallelize(AffineForOp forOp, 349 ArrayRef<LoopReduction> parallelReductions) { 350 // Fail early if there are iter arguments that are not reductions. 351 unsigned numReductions = parallelReductions.size(); 352 if (numReductions != forOp.getNumIterOperands()) 353 return failure(); 354 355 Location loc = forOp.getLoc(); 356 OpBuilder outsideBuilder(forOp); 357 AffineMap lowerBoundMap = forOp.getLowerBoundMap(); 358 ValueRange lowerBoundOperands = forOp.getLowerBoundOperands(); 359 AffineMap upperBoundMap = forOp.getUpperBoundMap(); 360 ValueRange upperBoundOperands = forOp.getUpperBoundOperands(); 361 362 // Creating empty 1-D affine.parallel op. 363 auto reducedValues = llvm::to_vector<4>(llvm::map_range( 364 parallelReductions, [](const LoopReduction &red) { return red.value; })); 365 auto reductionKinds = llvm::to_vector<4>(llvm::map_range( 366 parallelReductions, [](const LoopReduction &red) { return red.kind; })); 367 AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>( 368 loc, ValueRange(reducedValues).getTypes(), reductionKinds, 369 llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands, 370 llvm::makeArrayRef(upperBoundMap), upperBoundOperands, 371 llvm::makeArrayRef(forOp.getStep())); 372 // Steal the body of the old affine for op. 373 newPloop.getRegion().takeBody(forOp.getRegion()); 374 Operation *yieldOp = &newPloop.getBody()->back(); 375 376 // Handle the initial values of reductions because the parallel loop always 377 // starts from the neutral value. 
378 SmallVector<Value> newResults; 379 newResults.reserve(numReductions); 380 for (unsigned i = 0; i < numReductions; ++i) { 381 Value init = forOp.getIterOperands()[i]; 382 // This works because we are only handling single-op reductions at the 383 // moment. A switch on reduction kind or a mechanism to collect operations 384 // participating in the reduction will be necessary for multi-op reductions. 385 Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp(); 386 assert(reductionOp && "yielded value is expected to be produced by an op"); 387 outsideBuilder.getInsertionBlock()->getOperations().splice( 388 outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(), 389 reductionOp); 390 reductionOp->setOperands({init, newPloop->getResult(i)}); 391 forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0)); 392 } 393 394 // Update the loop terminator to yield reduced values bypassing the reduction 395 // operation itself (now moved outside of the loop) and erase the block 396 // arguments that correspond to reductions. Note that the loop always has one 397 // "main" induction variable whenc coming from a non-parallel for. 398 unsigned numIVs = 1; 399 yieldOp->setOperands(reducedValues); 400 newPloop.getBody()->eraseArguments( 401 llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs))); 402 403 forOp.erase(); 404 return success(); 405 } 406 407 // Returns success if any hoisting happened. 408 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { 409 // Bail out early if the ifOp returns a result. TODO: Consider how to 410 // properly support this case. 411 if (ifOp.getNumResults() != 0) 412 return failure(); 413 414 // Apply canonicalization patterns and folding - this is necessary for the 415 // hoisting check to be correct (operands should be composed), and to be more 416 // effective (no unused operands). 
Since the pattern rewriter's folding is 417 // entangled with application of patterns, we may fold/end up erasing the op, 418 // in which case we return with `folded` being set. 419 RewritePatternSet patterns(ifOp.getContext()); 420 AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); 421 bool erased; 422 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 423 (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); 424 if (erased) { 425 if (folded) 426 *folded = true; 427 return failure(); 428 } 429 if (folded) 430 *folded = false; 431 432 // The folding above should have ensured this, but the affine.if's 433 // canonicalization is missing composition of affine.applys into it. 434 assert(llvm::all_of(ifOp.getOperands(), 435 [](Value v) { 436 return isTopLevelValue(v) || isForInductionVar(v); 437 }) && 438 "operands not composed"); 439 440 // We are going hoist as high as possible. 441 // TODO: this could be customized in the future. 442 auto *hoistOverOp = getOutermostInvariantForOp(ifOp); 443 444 AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp); 445 // Nothing to hoist over. 446 if (hoistedIfOp == ifOp) 447 return failure(); 448 449 // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up 450 // a sequence of affine.fors that are all perfectly nested). 451 (void)applyPatternsAndFoldGreedily( 452 hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(), 453 frozenPatterns); 454 455 return success(); 456 } 457 458 // Return the min expr after replacing the given dim. 459 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min, 460 AffineExpr max, bool positivePath) { 461 if (e == dim) 462 return positivePath ? 
min : max; 463 if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) { 464 AffineExpr lhs = bin.getLHS(); 465 AffineExpr rhs = bin.getRHS(); 466 if (bin.getKind() == mlir::AffineExprKind::Add) 467 return substWithMin(lhs, dim, min, max, positivePath) + 468 substWithMin(rhs, dim, min, max, positivePath); 469 470 auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>(); 471 auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>(); 472 if (c1 && c1.getValue() < 0) 473 return getAffineBinaryOpExpr( 474 bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath)); 475 if (c2 && c2.getValue() < 0) 476 return getAffineBinaryOpExpr( 477 bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2); 478 return getAffineBinaryOpExpr( 479 bin.getKind(), substWithMin(lhs, dim, min, max, positivePath), 480 substWithMin(rhs, dim, min, max, positivePath)); 481 } 482 return e; 483 } 484 485 void mlir::normalizeAffineParallel(AffineParallelOp op) { 486 // Loops with min/max in bounds are not normalized at the moment. 487 if (op.hasMinMaxBounds()) 488 return; 489 490 AffineMap lbMap = op.getLowerBoundsMap(); 491 SmallVector<int64_t, 8> steps = op.getSteps(); 492 // No need to do any work if the parallel op is already normalized. 
493 bool isAlreadyNormalized = 494 llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) { 495 int64_t step = std::get<0>(tuple); 496 auto lbExpr = 497 std::get<1>(tuple).template dyn_cast<AffineConstantExpr>(); 498 return lbExpr && lbExpr.getValue() == 0 && step == 1; 499 }); 500 if (isAlreadyNormalized) 501 return; 502 503 AffineValueMap ranges; 504 AffineValueMap::difference(op.getUpperBoundsValueMap(), 505 op.getLowerBoundsValueMap(), &ranges); 506 auto builder = OpBuilder::atBlockBegin(op.getBody()); 507 auto zeroExpr = builder.getAffineConstantExpr(0); 508 SmallVector<AffineExpr, 8> lbExprs; 509 SmallVector<AffineExpr, 8> ubExprs; 510 for (unsigned i = 0, e = steps.size(); i < e; ++i) { 511 int64_t step = steps[i]; 512 513 // Adjust the lower bound to be 0. 514 lbExprs.push_back(zeroExpr); 515 516 // Adjust the upper bound expression: 'range / step'. 517 AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step); 518 ubExprs.push_back(ubExpr); 519 520 // Adjust the corresponding IV: 'lb + i * step'. 521 BlockArgument iv = op.getBody()->getArgument(i); 522 AffineExpr lbExpr = lbMap.getResult(i); 523 unsigned nDims = lbMap.getNumDims(); 524 auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step; 525 auto map = AffineMap::get(/*dimCount=*/nDims + 1, 526 /*symbolCount=*/lbMap.getNumSymbols(), expr); 527 528 // Use an 'affine.apply' op that will be simplified later in subsequent 529 // canonicalizations. 
530 OperandRange lbOperands = op.getLowerBoundsOperands(); 531 OperandRange dimOperands = lbOperands.take_front(nDims); 532 OperandRange symbolOperands = lbOperands.drop_front(nDims); 533 SmallVector<Value, 8> applyOperands{dimOperands}; 534 applyOperands.push_back(iv); 535 applyOperands.append(symbolOperands.begin(), symbolOperands.end()); 536 auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands); 537 iv.replaceAllUsesExcept(apply, apply); 538 } 539 540 SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1); 541 op.setSteps(newSteps); 542 auto newLowerMap = AffineMap::get( 543 /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext()); 544 op.setLowerBounds({}, newLowerMap); 545 auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(), 546 ubExprs, op.getContext()); 547 op.setUpperBounds(ranges.getOperands(), newUpperMap); 548 } 549 550 /// Normalizes affine.for ops. If the affine.for op has only a single iteration 551 /// only then it is simply promoted, else it is normalized in the traditional 552 /// way, by converting the lower bound to zero and loop step to one. The upper 553 /// bound is set to the trip count of the loop. For now, original loops must 554 /// have lower bound with a single result only. There is no such restriction on 555 /// upper bounds. 556 LogicalResult mlir::normalizeAffineFor(AffineForOp op) { 557 if (succeeded(promoteIfSingleIteration(op))) 558 return success(); 559 560 // Check if the forop is already normalized. 561 if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) && 562 (op.getStep() == 1)) 563 return success(); 564 565 // Check if the lower bound has a single result only. Loops with a max lower 566 // bound can't be normalized without additional support like 567 // affine.execute_region's. If the lower bound does not have a single result 568 // then skip this op. 
569 if (op.getLowerBoundMap().getNumResults() != 1) 570 return failure(); 571 572 Location loc = op.getLoc(); 573 OpBuilder opBuilder(op); 574 int64_t origLoopStep = op.getStep(); 575 576 // Calculate upperBound for normalized loop. 577 SmallVector<Value, 4> ubOperands; 578 AffineBound lb = op.getLowerBound(); 579 AffineBound ub = op.getUpperBound(); 580 ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands()); 581 AffineMap origLbMap = lb.getMap(); 582 AffineMap origUbMap = ub.getMap(); 583 584 // Add dimension operands from upper/lower bound. 585 for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) 586 ubOperands.push_back(ub.getOperand(j)); 587 for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j) 588 ubOperands.push_back(lb.getOperand(j)); 589 590 // Add symbol operands from upper/lower bound. 591 for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) 592 ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j)); 593 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j) 594 ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j)); 595 596 // Add original result expressions from lower/upper bound map. 597 SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(), 598 origLbMap.getResults().end()); 599 SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(), 600 origUbMap.getResults().end()); 601 SmallVector<AffineExpr, 4> newUbExprs; 602 603 // The original upperBound can have more than one result. For the new 604 // upperBound of this loop, take difference of all possible combinations of 605 // the ub results and lb result and ceildiv with the loop step. For e.g., 606 // 607 // affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0) 608 // will have an upperBound map as, 609 // affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1, (1024 - 0) ceildiv 610 // 1)>(%i0) 611 // 612 // Insert all combinations of upper/lower bound results. 
613 for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) { 614 newUbExprs.push_back( 615 (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep)); 616 } 617 618 // Construct newUbMap. 619 AffineMap newUbMap = 620 AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(), 621 origLbMap.getNumSymbols() + origUbMap.getNumSymbols(), 622 newUbExprs, opBuilder.getContext()); 623 canonicalizeMapAndOperands(&newUbMap, &ubOperands); 624 625 SmallVector<Value, 4> lbOperands(lb.getOperands().begin(), 626 lb.getOperands().begin() + 627 lb.getMap().getNumDims()); 628 629 // Normalize the loop. 630 op.setUpperBound(ubOperands, newUbMap); 631 op.setLowerBound({}, opBuilder.getConstantAffineMap(0)); 632 op.setStep(1); 633 634 // Calculate the Value of new loopIV. Create affine.apply for the value of 635 // the loopIV in normalized loop. 636 opBuilder.setInsertionPointToStart(op.getBody()); 637 // Add an extra dim operand for loopIV. 638 lbOperands.push_back(op.getInductionVar()); 639 // Add symbol operands from lower bound. 640 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j) 641 lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j)); 642 643 AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims()); 644 AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0); 645 AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1, 646 origLbMap.getNumSymbols(), newIVExpr); 647 canonicalizeMapAndOperands(&ivMap, &lbOperands); 648 Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands); 649 op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); 650 return success(); 651 } 652 653 /// Ensure that all operations that could be executed after `start` 654 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path 655 /// between the operations) do not have the potential memory effect 656 /// `EffectType` on `memOp`. 
`memOp` is an operation that reads or writes to 657 /// a memref. For example, if `EffectType` is MemoryEffects::Write, this method 658 /// will check if there is no write to the memory between `start` and `memOp` 659 /// that would change the read within `memOp`. 660 template <typename EffectType, typename T> 661 static bool hasNoInterveningEffect(Operation *start, T memOp) { 662 Value memref = memOp.getMemRef(); 663 bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() || 664 memref.getDefiningOp<memref::AllocOp>(); 665 666 // A boolean representing whether an intervening operation could have impacted 667 // memOp. 668 bool hasSideEffect = false; 669 670 // Check whether the effect on memOp can be caused by a given operation op. 671 std::function<void(Operation *)> checkOperation = [&](Operation *op) { 672 // If the effect has alreay been found, early exit, 673 if (hasSideEffect) 674 return; 675 676 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) { 677 SmallVector<MemoryEffects::EffectInstance, 1> effects; 678 memEffect.getEffects(effects); 679 680 bool opMayHaveEffect = false; 681 for (auto effect : effects) { 682 // If op causes EffectType on a potentially aliasing location for 683 // memOp, mark as having the effect. 684 if (isa<EffectType>(effect.getEffect())) { 685 if (isOriginalAllocation && effect.getValue() && 686 (effect.getValue().getDefiningOp<memref::AllocaOp>() || 687 effect.getValue().getDefiningOp<memref::AllocOp>())) { 688 if (effect.getValue() != memref) 689 continue; 690 } 691 opMayHaveEffect = true; 692 break; 693 } 694 } 695 696 if (!opMayHaveEffect) 697 return; 698 699 // If the side effect comes from an affine read or write, try to 700 // prove the side effecting `op` cannot reach `memOp`. 701 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) { 702 MemRefAccess srcAccess(op); 703 MemRefAccess destAccess(memOp); 704 // Dependence analysis is only correct if both ops operate on the same 705 // memref. 
706 if (srcAccess.memref == destAccess.memref) { 707 FlatAffineValueConstraints dependenceConstraints; 708 709 // Number of loops containing the start op and the ending operation. 710 unsigned minSurroundingLoops = 711 getNumCommonSurroundingLoops(*start, *memOp); 712 713 // Number of loops containing the operation `op` which has the 714 // potential memory side effect and can occur on a path between 715 // `start` and `memOp`. 716 unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp); 717 718 // For ease, let's consider the case that `op` is a store and we're 719 // looking for other potential stores (e.g `op`) that overwrite memory 720 // after `start`, and before being read in `memOp`. In this case, we 721 // only need to consider other potential stores with depth > 722 // minSurrounding loops since `start` would overwrite any store with a 723 // smaller number of surrounding loops before. 724 unsigned d; 725 for (d = nsLoops + 1; d > minSurroundingLoops; d--) { 726 DependenceResult result = checkMemrefAccessDependence( 727 srcAccess, destAccess, d, &dependenceConstraints, 728 /*dependenceComponents=*/nullptr); 729 if (hasDependence(result)) { 730 hasSideEffect = true; 731 return; 732 } 733 } 734 735 // No side effect was seen, simply return. 736 return; 737 } 738 } 739 hasSideEffect = true; 740 return; 741 } 742 743 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) { 744 // Recurse into the regions for this op and check whether the internal 745 // operations may have the side effect `EffectType` on memOp. 746 for (Region ®ion : op->getRegions()) 747 for (Block &block : region) 748 for (Operation &op : block) 749 checkOperation(&op); 750 return; 751 } 752 753 // Otherwise, conservatively assume generic operations have the effect 754 // on the operation 755 hasSideEffect = true; 756 }; 757 758 // Check all paths from ancestor op `parent` to the operation `to` for the 759 // effect. It is known that `to` must be contained within `parent`. 
760 auto until = [&](Operation *parent, Operation *to) { 761 // TODO check only the paths from `parent` to `to`. 762 // Currently we fallback and check the entire parent op, rather than 763 // just the paths from the parent path, stopping after reaching `to`. 764 // This is conservatively correct, but could be made more aggressive. 765 assert(parent->isAncestor(to)); 766 checkOperation(parent); 767 }; 768 769 // Check for all paths from operation `from` to operation `untilOp` for the 770 // given memory effect. 771 std::function<void(Operation *, Operation *)> recur = 772 [&](Operation *from, Operation *untilOp) { 773 assert( 774 from->getParentRegion()->isAncestor(untilOp->getParentRegion()) && 775 "Checking for side effect between two operations without a common " 776 "ancestor"); 777 778 // If the operations are in different regions, recursively consider all 779 // path from `from` to the parent of `to` and all paths from the parent 780 // of `to` to `to`. 781 if (from->getParentRegion() != untilOp->getParentRegion()) { 782 recur(from, untilOp->getParentOp()); 783 until(untilOp->getParentOp(), untilOp); 784 return; 785 } 786 787 // Now, assuming that `from` and `to` exist in the same region, perform 788 // a CFG traversal to check all the relevant operations. 789 790 // Additional blocks to consider. 791 SmallVector<Block *, 2> todoBlocks; 792 { 793 // First consider the parent block of `from` an check all operations 794 // after `from`. 795 for (auto iter = ++from->getIterator(), end = from->getBlock()->end(); 796 iter != end && &*iter != untilOp; ++iter) { 797 checkOperation(&*iter); 798 } 799 800 // If the parent of `from` doesn't contain `to`, add the successors 801 // to the list of blocks to check. 802 if (untilOp->getBlock() != from->getBlock()) 803 for (Block *succ : from->getBlock()->getSuccessors()) 804 todoBlocks.push_back(succ); 805 } 806 807 SmallPtrSet<Block *, 4> done; 808 // Traverse the CFG until hitting `to`. 
// Worklist traversal over the blocks queued in `todoBlocks`.
    while (!todoBlocks.empty()) {
      Block *blk = todoBlocks.pop_back_val();
      // Visit each block at most once.
      if (done.count(blk))
        continue;
      done.insert(blk);
      for (auto &op : *blk) {
        // Stop scanning this block once `untilOp` is reached; its successors
        // are then not queued from here.
        if (&op == untilOp)
          break;
        checkOperation(&op);
        // After checking the terminator, queue the block's successors.
        if (&op == blk->getTerminator())
          for (Block *succ : blk->getSuccessors())
            todoBlocks.push_back(succ);
      }
    }
  };
  recur(start, memOp);
  return !hasSideEffect;
}

/// Attempt to eliminate loadOp by replacing it with a value stored into memory
/// which the load is guaranteed to retrieve. This check involves three
/// components: 1) The store and load must be on the same location 2) The store
/// must dominate (and therefore must always occur prior to) the load 3) No
/// other operations will overwrite the memory loaded between the given load
/// and store. If such a value exists, the replaced `loadOp` will be added to
/// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
///
/// Returns success() if the uses of the load were replaced, and failure() if
/// no store qualified or if the stored value's type differs from the type of
/// the loaded value. The load itself is not erased here; it is only recorded.
static LogicalResult forwardStoreToLoad(
    AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
    SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {

  // The store op candidate for forwarding that satisfies all conditions
  // to replace the load, if any.
  Operation *lastWriteStoreOp = nullptr;

  // Candidates are the affine write ops among the users of the loaded memref.
  for (auto *user : loadOp.getMemRef().getUsers()) {
    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
    if (!storeOp)
      continue;
    MemRefAccess srcAccess(storeOp);
    MemRefAccess destAccess(loadOp);

    // 1. Check if the store and the load have mathematically equivalent
    // affine access functions; this implies that they statically refer to the
    // same single memref element. As an example this filters out cases like:
    //     store %A[%i0 + 1]
    //     load %A[%i0]
    //     store %A[%M]
    //     load %A[%N]
    // Use the AffineValueMap difference based memref access equality checking.
    if (srcAccess != destAccess)
      continue;

    // 2. The store has to dominate the load op to be candidate.
    if (!domInfo.dominates(storeOp, loadOp))
      continue;

    // 3. Ensure there is no intermediate operation which could replace the
    // value in memory.
    if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp))
      continue;

    // We now have a candidate for forwarding. At most one store is expected
    // to pass all three checks.
    assert(lastWriteStoreOp == nullptr &&
           "multiple simulataneous replacement stores");
    lastWriteStoreOp = storeOp;
  }

  if (!lastWriteStoreOp)
    return failure();

  // Perform the actual store to load forwarding.
  Value storeVal =
      cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
  // Check if 2 values have the same shape. This is needed for affine vector
  // loads and stores.
  if (storeVal.getType() != loadOp.getValue().getType())
    return failure();
  loadOp.getValue().replaceAllUsesWith(storeVal);
  // Record the memref for a later sweep to optimize away.
  memrefsToErase.insert(loadOp.getMemRef());
  // Record this to erase later.
  loadOpsToErase.push_back(loadOp);
  return success();
}

// This attempts to find stores which have no impact on the final result.
// A writing op writeA will be eliminated if there exists an op writeB if
// 1) writeA and writeB have mathematically equivalent affine access functions.
// 2) writeB postdominates writeA.
// 3) There is no potential read between writeA and writeB.
//
// When such a writeB is found, writeA is appended to `opsToErase` (once); the
// op is not erased here.
static void findUnusedStore(AffineWriteOpInterface writeA,
                            SmallVectorImpl<Operation *> &opsToErase,
                            PostDominanceInfo &postDominanceInfo) {

  for (Operation *user : writeA.getMemRef().getUsers()) {
    // Only consider writing operations.
    auto writeB = dyn_cast<AffineWriteOpInterface>(user);
    if (!writeB)
      continue;

    // The operations must be distinct.
    if (writeB == writeA)
      continue;

    // Both operations must lie in the same region.
    if (writeB->getParentRegion() != writeA->getParentRegion())
      continue;

    // Both operations must write to the same memory.
    MemRefAccess srcAccess(writeB);
    MemRefAccess destAccess(writeA);

    if (srcAccess != destAccess)
      continue;

    // writeB must postdominate writeA.
    if (!postDominanceInfo.postDominates(writeB, writeA))
      continue;

    // There cannot be an operation which reads from memory between
    // the two writes.
    if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB))
      continue;

    // One overwriting store is enough; stop searching.
    opsToErase.push_back(writeA);
    break;
  }
}

// The load to load forwarding / redundant load elimination is similar to the
// store to load forwarding.
// loadA will be replaced with loadB if:
// 1) loadA and loadB have mathematically equivalent affine access functions.
// 2) loadB dominates loadA.
// 3) There is no write between loadA and loadB.
//
// On replacement, loadA is appended to `loadOpsToErase`; it is not erased
// here.
static void loadCSE(AffineReadOpInterface loadA,
                    SmallVectorImpl<Operation *> &loadOpsToErase,
                    DominanceInfo &domInfo) {
  // Loads that satisfy all conditions to replace loadA.
  SmallVector<AffineReadOpInterface, 4> loadCandidates;
  for (auto *user : loadA.getMemRef().getUsers()) {
    auto loadB = dyn_cast<AffineReadOpInterface>(user);
    if (!loadB || loadB == loadA)
      continue;

    MemRefAccess srcAccess(loadB);
    MemRefAccess destAccess(loadA);

    // 1. The accesses have to be to the same location.
    if (srcAccess != destAccess) {
      continue;
    }

    // 2. loadB has to dominate loadA to be a candidate.
    if (!domInfo.dominates(loadB, loadA))
      continue;

    // 3. There is no write between loadA and loadB.
    if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(),
                                                      loadA))
      continue;

    // Check if two values have the same shape. This is needed for affine vector
    // loads.
    if (loadB.getValue().getType() != loadA.getValue().getType())
      continue;

    loadCandidates.push_back(loadB);
  }

  // Of the legal load candidates, use the one that dominates all others
  // to minimize the subsequent need to loadCSE
  Value loadB;
  for (AffineReadOpInterface option : loadCandidates) {
    // Note: despite its name, `depStore` ranges over the candidate loads.
    if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) {
          return depStore == option ||
                 domInfo.dominates(option.getOperation(),
                                   depStore.getOperation());
        })) {
      loadB = option.getValue();
      break;
    }
  }

  // `loadB` stays null when there were no candidates or none of them
  // dominates all the others; loadA is then left untouched.
  if (loadB) {
    loadA.getValue().replaceAllUsesWith(loadB);
    // Record this to erase later.
    loadOpsToErase.push_back(loadA);
  }
}

// The store to load forwarding and load CSE rely on three conditions:
//
// 1) store/load providing a replacement value and load being replaced need to
// have mathematically equivalent affine access functions (checked after full
// composition of load/store operands); this implies that they access the same
// single memref element for all iterations of the common surrounding loop,
//
// 2) the store/load op should dominate the load op,
//
// 3) no operation that may write to memory read by the load being replaced can
// occur after executing the instruction (load or store) providing the
// replacement value and before the load being replaced (thus potentially
// allowing overwriting the memory read by the load).
//
// The above conditions are simple to check, sufficient, and powerful for most
// cases in practice - they are sufficient, but not necessary --- since they
// don't reason about loops that are guaranteed to execute at least once or
// multiple sources to forward from.
//
// TODO: more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO: do general dead store elimination for memref's. This pass
// currently only eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
// Runs three phases over `f`: (1) store-to-load forwarding, falling back to
// load CSE, for every affine read op; (2) elimination of overwritten stores
// for every affine write op; (3) removal of locally allocated memrefs that are
// left with only stores and deallocs as users.
void mlir::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
                               PostDominanceInfo &postDomInfo) {
  // Load op's whose results were replaced by those forwarded from stores.
  SmallVector<Operation *, 8> opsToErase;

  // A list of memref's that are potentially dead / could be eliminated.
  SmallPtrSet<Value, 4> memrefsToErase;

  // Walk all load's and perform store to load forwarding. Load CSE is only
  // attempted when forwarding from a store did not succeed.
  f.walk([&](AffineReadOpInterface loadOp) {
    if (failed(
            forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
      loadCSE(loadOp, opsToErase, domInfo);
    }
  });

  // Erase all load op's whose results were replaced with store fwd'ed ones.
  for (auto *op : opsToErase)
    op->erase();
  // The list is reused below for the stores to erase.
  opsToErase.clear();

  // Walk all store's and perform unused store elimination
  f.walk([&](AffineWriteOpInterface storeOp) {
    findUnusedStore(storeOp, opsToErase, postDomInfo);
  });
  // Erase all store op's which don't impact the program
  for (auto *op : opsToErase)
    op->erase();

  // Check if the store fwd'ed memrefs are now left with only stores and can
  // thus be completely deleted. Note: the canonicalize pass should be able
  // to do this as well, but we'll do it here since we collected these anyway.
  for (auto memref : memrefsToErase) {
    // If the memref hasn't been alloc'ed in this function, skip.
    Operation *defOp = memref.getDefiningOp();
    if (!defOp || !isa<memref::AllocOp>(defOp))
      // TODO: if the memref was returned by a 'call' operation, we
      // could still erase it if the call had no side-effects.
      continue;
    // Skip the memref if any user is something other than an affine write or
    // a dealloc.
    if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) {
          return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp);
        }))
      continue;

    // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto *user : llvm::make_early_inc_range(memref.getUsers()))
      user->erase();
    defOp->erase();
  }
}

// Perform the replacement in `op`.
//
// Replaces the single use of `oldMemRef` in `op` with `newMemRef`.
//  - Returns success() without changes if `op` does not use `oldMemRef`.
//  - Returns failure() if `op` uses `oldMemRef` more than once (after
//    asserting in assert-enabled builds), or if `op` does not implement
//    AffineMapAccessInterface and `allowNonDereferencingOps` is false.
//  - For an op without AffineMapAccessInterface (and
//    `allowNonDereferencingOps` set), the operand is swapped in place.
//  - Otherwise a new operation of the same name is created right before `op`
//    with rewritten access indices, all uses of `op` are redirected to it, and
//    `op` is erased.
// The new access indices are `extraIndices` followed by the results of
// `indexRemap` applied to (`extraOperands`, old access indices,
// `symbolOperands`); without a remap (or with an identity remap) those
// operands are used directly.
LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
                                             Operation *op,
                                             ArrayRef<Value> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value> extraOperands,
                                             ArrayRef<Value> symbolOperands,
                                             bool allowNonDereferencingOps) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank; // unused in opt mode
  // Sanity-check that the remap's inputs/results are consistent with the
  // ranks of both memrefs and with the extra operands supplied.
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  // Operand positions at which `op` uses `oldMemRef`.
  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO: extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  // New ops are inserted right before `op`.
  OpBuilder builder(op);
  // The following checks if op is dereferencing memref and performs the access
  // index rewrites.
  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
  if (!affMapAccInterface) {
    if (!allowNonDereferencingOps) {
      // Failure: memref used in a non-dereferencing context (potentially
      // escapes); no replacement in these cases unless allowNonDereferencingOps
      // is set.
      return failure();
    }
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // Perform index rewrites for the dereferencing op and then replace the op
  NamedAttribute oldMapAttrPair =
      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
  AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  // The map operands are the `oldMapNumInputs` operands that immediately
  // follow the memref operand.
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value, 4> oldMemRefOperands;
  // Temporary affine.apply ops created here; the dead ones are erased after
  // composition below.
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    // One single-result affine.apply per result of the old access map.
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // Identity access map: the map operands are the indices themselves.
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  // Operand order for the remap: extra operands, old indices, symbols.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());

  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (Value extraIndex : extraIndices) {
    // Note: the first assert dereferences the defining op, so it presumes
    // every extra index is an op result rather than a block argument.
    assert(extraIndex.getDefiningOp()->getNumResults() == 1 &&
           "single result op's expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create new fully composed AffineMap for new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO: Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();

  OperationState state(op->getLoc(), op->getName());
  // Construct the new operation using this memref.
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());

  // Add attribute for 'newMap', other Attributes do not change.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.getName() == oldMapAttrPair.getName())
      state.attributes.push_back({namedAttr.getName(), newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }

  // Create the new operation.
  auto *repOp = builder.create(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

// Replaces uses of `oldMemRef` with `newMemRef` across all of its users,
// optionally restricted to users dominated by `domOpFilter` and/or
// post-dominated by `postDomOpFilter`. Dealloc users are skipped unless
// `replaceInDeallocOp` is set. Returns failure(), before any replacement is
// performed, if a considered user does not implement AffineMapAccessInterface
// and either `allowNonDereferencingOps` is false or the user lacks the
// MemRefsNormalizable trait.
LogicalResult mlir::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domOpFilter,
    Operation *postDomOpFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
         newMemRef.getType().cast<MemRefType>().getElementType());

  // Dominance information is only computed when the corresponding filter is
  // provided, on the function enclosing the filter op.
  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domOpFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domOpFilter->getParentOfType<func::FuncOp>());

  if (postDomOpFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomOpFilter->getParentOfType<func::FuncOp>());

  // Walk all uses of old memref; collect ops to perform replacement. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    // Skip this use if it's not dominated by domOpFilter.
    if (domOpFilter && !domInfo->dominates(domOpFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomOpFilter.
    if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp)
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Memref replacement failed: non-deferencing memref op: \n"
                   << *op << '\n');
        return failure();
      }
      // Non-dereferencing ops with the MemRefsNormalizable trait are
      // supported for replacement.
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
        LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
                                   "memrefs normalizable trait: \n"
                                << *op << '\n');
        return failure();
      }
    }

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  // All collected ops passed the checks above, so the per-op replacement is
  // treated as infallible here.
  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single result affine
/// apply operations, results of which are exclusively used by this operation
/// operation. The operands of these newly created affine apply ops are
/// guaranteed to be loop iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (for eg.
/// different shifts/delays).
1363 /// 1364 /// Returns nullptr either if none of opInst's operands were the result of an 1365 /// affine.apply and thus there was no affine computation slice to create, or if 1366 /// all the affine.apply op's supplying operands to this opInst did not have any 1367 /// uses besides this opInst; otherwise returns the list of affine.apply 1368 /// operations created in output argument `sliceOps`. 1369 void mlir::createAffineComputationSlice( 1370 Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) { 1371 // Collect all operands that are results of affine apply ops. 1372 SmallVector<Value, 4> subOperands; 1373 subOperands.reserve(opInst->getNumOperands()); 1374 for (auto operand : opInst->getOperands()) 1375 if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp())) 1376 subOperands.push_back(operand); 1377 1378 // Gather sequence of AffineApplyOps reachable from 'subOperands'. 1379 SmallVector<Operation *, 4> affineApplyOps; 1380 getReachableAffineApplyOps(subOperands, affineApplyOps); 1381 // Skip transforming if there are no affine maps to compose. 1382 if (affineApplyOps.empty()) 1383 return; 1384 1385 // Check if all uses of the affine apply op's lie only in this op op, in 1386 // which case there would be nothing to do. 1387 bool localized = true; 1388 for (auto *op : affineApplyOps) { 1389 for (auto result : op->getResults()) { 1390 for (auto *user : result.getUsers()) { 1391 if (user != opInst) { 1392 localized = false; 1393 break; 1394 } 1395 } 1396 } 1397 } 1398 if (localized) 1399 return; 1400 1401 OpBuilder builder(opInst); 1402 SmallVector<Value, 4> composedOpOperands(subOperands); 1403 auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); 1404 fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); 1405 1406 // Create an affine.apply for each of the map results. 
1407 sliceOps->reserve(composedMap.getNumResults()); 1408 for (auto resultExpr : composedMap.getResults()) { 1409 auto singleResMap = AffineMap::get(composedMap.getNumDims(), 1410 composedMap.getNumSymbols(), resultExpr); 1411 sliceOps->push_back(builder.create<AffineApplyOp>( 1412 opInst->getLoc(), singleResMap, composedOpOperands)); 1413 } 1414 1415 // Construct the new operands that include the results from the composed 1416 // affine apply op above instead of existing ones (subOperands). So, they 1417 // differ from opInst's operands only for those operands in 'subOperands', for 1418 // which they will be replaced by the corresponding one from 'sliceOps'. 1419 SmallVector<Value, 4> newOperands(opInst->getOperands()); 1420 for (unsigned i = 0, e = newOperands.size(); i < e; i++) { 1421 // Replace the subOperands from among the new operands. 1422 unsigned j, f; 1423 for (j = 0, f = subOperands.size(); j < f; j++) { 1424 if (newOperands[i] == subOperands[j]) 1425 break; 1426 } 1427 if (j < subOperands.size()) { 1428 newOperands[i] = (*sliceOps)[j]; 1429 } 1430 } 1431 for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { 1432 opInst->setOperand(idx, newOperands[idx]); 1433 } 1434 } 1435 1436 /// Enum to set patterns of affine expr in tiled-layout map. 1437 /// TileFloorDiv: <dim expr> div <tile size> 1438 /// TileMod: <dim expr> mod <tile size> 1439 /// TileNone: None of the above 1440 /// Example: 1441 /// #tiled_2d_128x256 = affine_map<(d0, d1) 1442 /// -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)> 1443 /// "d0 div 128" and "d1 div 256" ==> TileFloorDiv 1444 /// "d0 mod 128" and "d1 mod 256" ==> TileMod 1445 enum TileExprPattern { TileFloorDiv, TileMod, TileNone }; 1446 1447 /// Check if `map` is a tiled layout. In the tiled layout, specific k dimensions 1448 /// being floordiv'ed by respective tile sizes appeare in a mod with the same 1449 /// tile sizes, and no other expression involves those k dimensions. 
This 1450 /// function stores a vector of tuples (`tileSizePos`) including AffineExpr for 1451 /// tile size, positions of corresponding `floordiv` and `mod`. If it is not a 1452 /// tiled layout, an empty vector is returned. 1453 static LogicalResult getTileSizePos( 1454 AffineMap map, 1455 SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) { 1456 // Create `floordivExprs` which is a vector of tuples including LHS and RHS of 1457 // `floordiv` and its position in `map` output. 1458 // Example: #tiled_2d_128x256 = affine_map<(d0, d1) 1459 // -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)> 1460 // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}. 1461 SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs; 1462 unsigned pos = 0; 1463 for (AffineExpr expr : map.getResults()) { 1464 if (expr.getKind() == AffineExprKind::FloorDiv) { 1465 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 1466 if (binaryExpr.getRHS().isa<AffineConstantExpr>()) 1467 floordivExprs.emplace_back( 1468 std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos)); 1469 } 1470 pos++; 1471 } 1472 // Not tiled layout if `floordivExprs` is empty. 1473 if (floordivExprs.empty()) { 1474 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{}; 1475 return success(); 1476 } 1477 1478 // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is 1479 // not tiled layout. 1480 for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) { 1481 AffineExpr floordivExprLHS = std::get<0>(fexpr); 1482 AffineExpr floordivExprRHS = std::get<1>(fexpr); 1483 unsigned floordivPos = std::get<2>(fexpr); 1484 1485 // Walk affinexpr of `map` output except `fexpr`, and check if LHS and RHS 1486 // of `fexpr` are used in LHS and RHS of `mod`. If LHS of `fexpr` is used 1487 // other expr, the map is not tiled layout. 
Example of non tiled layout: 1488 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)> 1489 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)> 1490 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod 1491 // 256)> 1492 bool found = false; 1493 pos = 0; 1494 for (AffineExpr expr : map.getResults()) { 1495 bool notTiled = false; 1496 if (pos != floordivPos) { 1497 expr.walk([&](AffineExpr e) { 1498 if (e == floordivExprLHS) { 1499 if (expr.getKind() == AffineExprKind::Mod) { 1500 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 1501 // If LHS and RHS of `mod` are the same with those of floordiv. 1502 if (floordivExprLHS == binaryExpr.getLHS() && 1503 floordivExprRHS == binaryExpr.getRHS()) { 1504 // Save tile size (RHS of `mod`), and position of `floordiv` and 1505 // `mod` if same expr with `mod` is not found yet. 1506 if (!found) { 1507 tileSizePos.emplace_back( 1508 std::make_tuple(binaryExpr.getRHS(), floordivPos, pos)); 1509 found = true; 1510 } else { 1511 // Non tiled layout: Have multilpe `mod` with the same LHS. 1512 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1513 // mod 256, d2 mod 256)> 1514 notTiled = true; 1515 } 1516 } else { 1517 // Non tiled layout: RHS of `mod` is different from `floordiv`. 1518 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1519 // mod 128)> 1520 notTiled = true; 1521 } 1522 } else { 1523 // Non tiled layout: LHS is the same, but not `mod`. 1524 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1525 // floordiv 256)> 1526 notTiled = true; 1527 } 1528 } 1529 }); 1530 } 1531 if (notTiled) { 1532 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{}; 1533 return success(); 1534 } 1535 pos++; 1536 } 1537 } 1538 return success(); 1539 } 1540 1541 /// Check if `dim` dimension of memrefType with `layoutMap` becomes dynamic 1542 /// after normalization. 
Dimensions that include dynamic dimensions in the map 1543 /// output will become dynamic dimensions. Return true if `dim` is dynamic 1544 /// dimension. 1545 /// 1546 /// Example: 1547 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)> 1548 /// 1549 /// If d1 is dynamic dimension, 2nd and 3rd dimension of map output are dynamic. 1550 /// memref<4x?xf32, #map0> ==> memref<4x?x?xf32> 1551 static bool 1552 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap, 1553 SmallVectorImpl<unsigned> &inMemrefTypeDynDims, 1554 MLIRContext *context) { 1555 bool isDynamicDim = false; 1556 AffineExpr expr = layoutMap.getResults()[dim]; 1557 // Check if affine expr of the dimension includes dynamic dimension of input 1558 // memrefType. 1559 expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) { 1560 if (e.isa<AffineDimExpr>()) { 1561 for (unsigned dm : inMemrefTypeDynDims) { 1562 if (e == getAffineDimExpr(dm, context)) { 1563 isDynamicDim = true; 1564 } 1565 } 1566 } 1567 }); 1568 return isDynamicDim; 1569 } 1570 1571 /// Create affine expr to calculate dimension size for a tiled-layout map. 1572 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput, 1573 TileExprPattern pat) { 1574 // Create map output for the patterns. 
1575 // "floordiv <tile size>" ==> "ceildiv <tile size>" 1576 // "mod <tile size>" ==> "<tile size>" 1577 AffineExpr newMapOutput; 1578 AffineBinaryOpExpr binaryExpr = nullptr; 1579 switch (pat) { 1580 case TileExprPattern::TileMod: 1581 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>(); 1582 newMapOutput = binaryExpr.getRHS(); 1583 break; 1584 case TileExprPattern::TileFloorDiv: 1585 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>(); 1586 newMapOutput = getAffineBinaryOpExpr( 1587 AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS()); 1588 break; 1589 default: 1590 newMapOutput = oldMapOutput; 1591 } 1592 return newMapOutput; 1593 } 1594 1595 /// Create new maps to calculate each dimension size of `newMemRefType`, and 1596 /// create `newDynamicSizes` from them by using AffineApplyOp. 1597 /// 1598 /// Steps for normalizing dynamic memrefs for a tiled layout map 1599 /// Example: 1600 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)> 1601 /// %0 = dim %arg0, %c1 :memref<4x?xf32> 1602 /// %1 = alloc(%0) : memref<4x?xf32, #map0> 1603 /// 1604 /// (Before this function) 1605 /// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only 1606 /// single layout map is supported. 1607 /// 1608 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It 1609 /// is memref<4x?x?xf32> in the above example. 1610 /// 1611 /// (In this function) 1612 /// 3. Create new maps to calculate each dimension of the normalized memrefType 1613 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the 1614 /// dimension size can be calculated by replacing "floordiv <tile size>" with 1615 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>". 1616 /// - New map in the above example 1617 /// #map0 = affine_map<(d0, d1) -> (d0)> 1618 /// #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)> 1619 /// #map2 = affine_map<(d0, d1) -> (32)> 1620 /// 1621 /// 4. Create AffineApplyOp to apply the new maps. 
The output of AffineApplyOp 1622 /// is used in dynamicSizes of new AllocOp. 1623 /// %0 = dim %arg0, %c1 : memref<4x?xf32> 1624 /// %c4 = arith.constant 4 : index 1625 /// %1 = affine.apply #map1(%c4, %0) 1626 /// %2 = affine.apply #map2(%c4, %0) 1627 static void createNewDynamicSizes(MemRefType oldMemRefType, 1628 MemRefType newMemRefType, AffineMap map, 1629 memref::AllocOp *allocOp, OpBuilder b, 1630 SmallVectorImpl<Value> &newDynamicSizes) { 1631 // Create new input for AffineApplyOp. 1632 SmallVector<Value, 4> inAffineApply; 1633 ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape(); 1634 unsigned dynIdx = 0; 1635 for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) { 1636 if (oldMemRefShape[d] < 0) { 1637 // Use dynamicSizes of allocOp for dynamic dimension. 1638 inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx]); 1639 dynIdx++; 1640 } else { 1641 // Create ConstantOp for static dimension. 1642 Attribute constantAttr = 1643 b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); 1644 inAffineApply.emplace_back( 1645 b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr)); 1646 } 1647 } 1648 1649 // Create new map to calculate each dimension size of new memref for each 1650 // original map output. Only for dynamic dimesion of `newMemRefType`. 1651 unsigned newDimIdx = 0; 1652 ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape(); 1653 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1654 (void)getTileSizePos(map, tileSizePos); 1655 for (AffineExpr expr : map.getResults()) { 1656 if (newMemRefShape[newDimIdx] < 0) { 1657 // Create new maps to calculate each dimension size of new memref. 
1658 enum TileExprPattern pat = TileExprPattern::TileNone; 1659 for (auto pos : tileSizePos) { 1660 if (newDimIdx == std::get<1>(pos)) 1661 pat = TileExprPattern::TileFloorDiv; 1662 else if (newDimIdx == std::get<2>(pos)) 1663 pat = TileExprPattern::TileMod; 1664 } 1665 AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat); 1666 AffineMap newMap = 1667 AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput); 1668 Value affineApp = 1669 b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply); 1670 newDynamicSizes.emplace_back(affineApp); 1671 } 1672 newDimIdx++; 1673 } 1674 } 1675 1676 // TODO: Currently works for static memrefs with a single layout map. 1677 LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) { 1678 MemRefType memrefType = allocOp->getType(); 1679 OpBuilder b(*allocOp); 1680 1681 // Fetch a new memref type after normalizing the old memref to have an 1682 // identity map layout. 1683 MemRefType newMemRefType = 1684 normalizeMemRefType(memrefType, b, allocOp->getSymbolOperands().size()); 1685 if (newMemRefType == memrefType) 1686 // Either memrefType already had an identity map or the map couldn't be 1687 // transformed to an identity map. 1688 return failure(); 1689 1690 Value oldMemRef = allocOp->getResult(); 1691 1692 SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands()); 1693 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 1694 memref::AllocOp newAlloc; 1695 // Check if `layoutMap` is a tiled layout. Only single layout map is 1696 // supported for normalizing dynamic memrefs. 
1697 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1698 (void)getTileSizePos(layoutMap, tileSizePos); 1699 if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) { 1700 MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>(); 1701 SmallVector<Value, 4> newDynamicSizes; 1702 createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, 1703 newDynamicSizes); 1704 // Add the new dynamic sizes in new AllocOp. 1705 newAlloc = 1706 b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType, 1707 newDynamicSizes, allocOp->getAlignmentAttr()); 1708 } else { 1709 newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType, 1710 allocOp->getAlignmentAttr()); 1711 } 1712 // Replace all uses of the old memref. 1713 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, 1714 /*extraIndices=*/{}, 1715 /*indexRemap=*/layoutMap, 1716 /*extraOperands=*/{}, 1717 /*symbolOperands=*/symbolOperands, 1718 /*domOpFilter=*/nullptr, 1719 /*postDomOpFilter=*/nullptr, 1720 /*allowNonDereferencingOps=*/true))) { 1721 // If it failed (due to escapes for example), bail out. 1722 newAlloc.erase(); 1723 return failure(); 1724 } 1725 // Replace any uses of the original alloc op and erase it. All remaining uses 1726 // have to be dealloc's; RAMUW above would've failed otherwise. 1727 assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) { 1728 return isa<memref::DeallocOp>(op); 1729 })); 1730 oldMemRef.replaceAllUsesWith(newAlloc); 1731 allocOp->erase(); 1732 return success(); 1733 } 1734 1735 MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, 1736 unsigned numSymbolicOperands) { 1737 unsigned rank = memrefType.getRank(); 1738 if (rank == 0) 1739 return memrefType; 1740 1741 if (memrefType.getLayout().isIdentity()) { 1742 // Either no maps is associated with this memref or this memref has 1743 // a trivial (identity) map. 
1744 return memrefType; 1745 } 1746 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 1747 1748 // We don't do any checks for one-to-one'ness; we assume that it is 1749 // one-to-one. 1750 1751 // Normalize only static memrefs and dynamic memrefs with a tiled-layout map 1752 // for now. 1753 // TODO: Normalize the other types of dynamic memrefs. 1754 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1755 (void)getTileSizePos(layoutMap, tileSizePos); 1756 if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty()) 1757 return memrefType; 1758 1759 // We have a single map that is not an identity map. Create a new memref 1760 // with the right shape and an identity layout map. 1761 ArrayRef<int64_t> shape = memrefType.getShape(); 1762 // FlatAffineValueConstraint may later on use symbolicOperands. 1763 FlatAffineValueConstraints fac(rank, numSymbolicOperands); 1764 SmallVector<unsigned, 4> memrefTypeDynDims; 1765 for (unsigned d = 0; d < rank; ++d) { 1766 // Use constraint system only in static dimensions. 1767 if (shape[d] > 0) { 1768 fac.addBound(IntegerPolyhedron::LB, d, 0); 1769 fac.addBound(IntegerPolyhedron::UB, d, shape[d] - 1); 1770 } else { 1771 memrefTypeDynDims.emplace_back(d); 1772 } 1773 } 1774 // We compose this map with the original index (logical) space to derive 1775 // the upper bounds for the new index space. 1776 unsigned newRank = layoutMap.getNumResults(); 1777 if (failed(fac.composeMatchingMap(layoutMap))) 1778 return memrefType; 1779 // TODO: Handle semi-affine maps. 1780 // Project out the old data dimensions. 1781 fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars()); 1782 SmallVector<int64_t, 4> newShape(newRank); 1783 for (unsigned d = 0; d < newRank; ++d) { 1784 // Check if each dimension of normalized memrefType is dynamic. 
1785 bool isDynDim = isNormalizedMemRefDynamicDim( 1786 d, layoutMap, memrefTypeDynDims, b.getContext()); 1787 if (isDynDim) { 1788 newShape[d] = -1; 1789 } else { 1790 // The lower bound for the shape is always zero. 1791 auto ubConst = fac.getConstantBound(IntegerPolyhedron::UB, d); 1792 // For a static memref and an affine map with no symbols, this is 1793 // always bounded. 1794 assert(ubConst && "should always have an upper bound"); 1795 if (ubConst.value() < 0) 1796 // This is due to an invalid map that maps to a negative space. 1797 return memrefType; 1798 // If dimension of new memrefType is dynamic, the value is -1. 1799 newShape[d] = ubConst.value() + 1; 1800 } 1801 } 1802 1803 // Create the new memref type after trivializing the old layout map. 1804 MemRefType newMemRefType = 1805 MemRefType::Builder(memrefType) 1806 .setShape(newShape) 1807 .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank))); 1808 1809 return newMemRefType; 1810 } 1811