1 //===- Utils.cpp ---- Utilities for affine dialect transformation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous transformation utilities for the Affine 10 // dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/Utils.h" 15 16 #include "mlir/Dialect/Affine/Analysis/Utils.h" 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 19 #include "mlir/Dialect/Affine/LoopUtils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/IR/AffineExprVisitor.h" 22 #include "mlir/IR/BlockAndValueMapping.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/IR/IntegerSet.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 27 #define DEBUG_TYPE "affine-utils" 28 29 using namespace mlir; 30 31 namespace { 32 /// Visit affine expressions recursively and build the sequence of operations 33 /// that correspond to it. Visitation functions return an Value of the 34 /// expression subtree they visited or `nullptr` on error. 35 class AffineApplyExpander 36 : public AffineExprVisitor<AffineApplyExpander, Value> { 37 public: 38 /// This internal class expects arguments to be non-null, checks must be 39 /// performed at the call site. 40 AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, 41 ValueRange symbolValues, Location loc) 42 : builder(builder), dimValues(dimValues), symbolValues(symbolValues), 43 loc(loc) {} 44 45 template <typename OpTy> 46 Value buildBinaryExpr(AffineBinaryOpExpr expr) { 47 auto lhs = visit(expr.getLHS()); 48 auto rhs = visit(expr.getRHS()); 49 if (!lhs || !rhs) 50 return nullptr; 51 auto op = builder.create<OpTy>(loc, lhs, rhs); 52 return op.getResult(); 53 } 54 55 Value visitAddExpr(AffineBinaryOpExpr expr) { 56 return buildBinaryExpr<arith::AddIOp>(expr); 57 } 58 59 Value visitMulExpr(AffineBinaryOpExpr expr) { 60 return buildBinaryExpr<arith::MulIOp>(expr); 61 } 62 63 /// Euclidean modulo operation: negative RHS is not allowed. 64 /// Remainder of the euclidean integer division is always non-negative. 65 /// 66 /// Implemented as 67 /// 68 /// a mod b = 69 /// let remainder = srem a, b; 70 /// negative = a < 0 in 71 /// select negative, remainder + b, remainder. 72 Value visitModExpr(AffineBinaryOpExpr expr) { 73 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 74 if (!rhsConst) { 75 emitError( 76 loc, 77 "semi-affine expressions (modulo by non-const) are not supported"); 78 return nullptr; 79 } 80 if (rhsConst.getValue() <= 0) { 81 emitError(loc, "modulo by non-positive value is not supported"); 82 return nullptr; 83 } 84 85 auto lhs = visit(expr.getLHS()); 86 auto rhs = visit(expr.getRHS()); 87 assert(lhs && rhs && "unexpected affine expr lowering failure"); 88 89 Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs); 90 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 91 Value isRemainderNegative = builder.create<arith::CmpIOp>( 92 loc, arith::CmpIPredicate::slt, remainder, zeroCst); 93 Value correctedRemainder = 94 builder.create<arith::AddIOp>(loc, remainder, rhs); 95 Value result = builder.create<arith::SelectOp>( 96 loc, isRemainderNegative, correctedRemainder, remainder); 97 return result; 98 } 99 100 /// Floor division operation (rounds towards negative infinity). 101 /// 102 /// For positive divisors, it can be implemented without branching and with a 103 /// single division operation as 104 /// 105 /// a floordiv b = 106 /// let negative = a < 0 in 107 /// let absolute = negative ? -a - 1 : a in 108 /// let quotient = absolute / b in 109 /// negative ? -quotient - 1 : quotient 110 Value visitFloorDivExpr(AffineBinaryOpExpr expr) { 111 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 112 if (!rhsConst) { 113 emitError( 114 loc, 115 "semi-affine expressions (division by non-const) are not supported"); 116 return nullptr; 117 } 118 if (rhsConst.getValue() <= 0) { 119 emitError(loc, "division by non-positive value is not supported"); 120 return nullptr; 121 } 122 123 auto lhs = visit(expr.getLHS()); 124 auto rhs = visit(expr.getRHS()); 125 assert(lhs && rhs && "unexpected affine expr lowering failure"); 126 127 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 128 Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1); 129 Value negative = builder.create<arith::CmpIOp>( 130 loc, arith::CmpIPredicate::slt, lhs, zeroCst); 131 Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs); 132 Value dividend = 133 builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs); 134 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 135 Value correctedQuotient = 136 builder.create<arith::SubIOp>(loc, noneCst, quotient); 137 Value result = builder.create<arith::SelectOp>(loc, negative, 138 correctedQuotient, quotient); 139 return result; 140 } 141 142 /// Ceiling division operation (rounds towards positive infinity). 143 /// 144 /// For positive divisors, it can be implemented without branching and with a 145 /// single division operation as 146 /// 147 /// a ceildiv b = 148 /// let negative = a <= 0 in 149 /// let absolute = negative ? -a : a - 1 in 150 /// let quotient = absolute / b in 151 /// negative ? -quotient : quotient + 1 152 Value visitCeilDivExpr(AffineBinaryOpExpr expr) { 153 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 154 if (!rhsConst) { 155 emitError(loc) << "semi-affine expressions (division by non-const) are " 156 "not supported"; 157 return nullptr; 158 } 159 if (rhsConst.getValue() <= 0) { 160 emitError(loc, "division by non-positive value is not supported"); 161 return nullptr; 162 } 163 auto lhs = visit(expr.getLHS()); 164 auto rhs = visit(expr.getRHS()); 165 assert(lhs && rhs && "unexpected affine expr lowering failure"); 166 167 Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); 168 Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1); 169 Value nonPositive = builder.create<arith::CmpIOp>( 170 loc, arith::CmpIPredicate::sle, lhs, zeroCst); 171 Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs); 172 Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst); 173 Value dividend = 174 builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented); 175 Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); 176 Value negatedQuotient = 177 builder.create<arith::SubIOp>(loc, zeroCst, quotient); 178 Value incrementedQuotient = 179 builder.create<arith::AddIOp>(loc, quotient, oneCst); 180 Value result = builder.create<arith::SelectOp>( 181 loc, nonPositive, negatedQuotient, incrementedQuotient); 182 return result; 183 } 184 185 Value visitConstantExpr(AffineConstantExpr expr) { 186 auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue()); 187 return op.getResult(); 188 } 189 190 Value visitDimExpr(AffineDimExpr expr) { 191 assert(expr.getPosition() < dimValues.size() && 192 "affine dim position out of range"); 193 return dimValues[expr.getPosition()]; 194 } 195 196 Value visitSymbolExpr(AffineSymbolExpr expr) { 197 assert(expr.getPosition() < symbolValues.size() && 198 "symbol dim position out of range"); 199 return symbolValues[expr.getPosition()]; 200 } 201 202 private: 203 OpBuilder &builder; 204 ValueRange dimValues; 205 ValueRange symbolValues; 206 207 Location loc; 208 }; 209 } // namespace 210 211 /// Create a sequence of operations that implement the `expr` applied to the 212 /// given dimension and symbol values. 213 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc, 214 AffineExpr expr, ValueRange dimValues, 215 ValueRange symbolValues) { 216 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); 217 } 218 219 /// Create a sequence of operations that implement the `affineMap` applied to 220 /// the given `operands` (as it it were an AffineApplyOp). 221 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder, 222 Location loc, 223 AffineMap affineMap, 224 ValueRange operands) { 225 auto numDims = affineMap.getNumDims(); 226 auto expanded = llvm::to_vector<8>( 227 llvm::map_range(affineMap.getResults(), 228 [numDims, &builder, loc, operands](AffineExpr expr) { 229 return expandAffineExpr(builder, loc, expr, 230 operands.take_front(numDims), 231 operands.drop_front(numDims)); 232 })); 233 if (llvm::all_of(expanded, [](Value v) { return v; })) 234 return expanded; 235 return None; 236 } 237 238 /// Promotes the `then` or the `else` block of `ifOp` (depending on whether 239 /// `elseBlock` is false or true) into `ifOp`'s containing block, and discards 240 /// the rest of the op. 241 static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) { 242 if (elseBlock) 243 assert(ifOp.hasElse() && "else block expected"); 244 245 Block *destBlock = ifOp->getBlock(); 246 Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock(); 247 destBlock->getOperations().splice( 248 Block::iterator(ifOp), srcBlock->getOperations(), srcBlock->begin(), 249 std::prev(srcBlock->end())); 250 ifOp.erase(); 251 } 252 253 /// Returns the outermost affine.for/parallel op that the `ifOp` is invariant 254 /// on. The `ifOp` could be hoisted and placed right before such an operation. 255 /// This method assumes that the ifOp has been canonicalized (to be correct and 256 /// effective). 257 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) { 258 // Walk up the parents past all for op that this conditional is invariant on. 259 auto ifOperands = ifOp.getOperands(); 260 auto *res = ifOp.getOperation(); 261 while (!isa<FuncOp>(res->getParentOp())) { 262 auto *parentOp = res->getParentOp(); 263 if (auto forOp = dyn_cast<AffineForOp>(parentOp)) { 264 if (llvm::is_contained(ifOperands, forOp.getInductionVar())) 265 break; 266 } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) { 267 for (auto iv : parallelOp.getIVs()) 268 if (llvm::is_contained(ifOperands, iv)) 269 break; 270 } else if (!isa<AffineIfOp>(parentOp)) { 271 // Won't walk up past anything other than affine.for/if ops. 272 break; 273 } 274 // You can always hoist up past any affine.if ops. 275 res = parentOp; 276 } 277 return res; 278 } 279 280 /// A helper for the mechanics of mlir::hoistAffineIfOp. Hoists `ifOp` just over 281 /// `hoistOverOp`. Returns the new hoisted op if any hoisting happened, 282 /// otherwise the same `ifOp`. 283 static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { 284 // No hoisting to do. 285 if (hoistOverOp == ifOp) 286 return ifOp; 287 288 // Create the hoisted 'if' first. Then, clone the op we are hoisting over for 289 // the else block. Then drop the else block of the original 'if' in the 'then' 290 // branch while promoting its then block, and analogously drop the 'then' 291 // block of the original 'if' from the 'else' branch while promoting its else 292 // block. 293 BlockAndValueMapping operandMap; 294 OpBuilder b(hoistOverOp); 295 auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(), 296 ifOp.getOperands(), 297 /*elseBlock=*/true); 298 299 // Create a clone of hoistOverOp to use for the else branch of the hoisted 300 // conditional. The else block may get optimized away if empty. 301 Operation *hoistOverOpClone = nullptr; 302 // We use this unique name to identify/find `ifOp`'s clone in the else 303 // version. 304 StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting"); 305 operandMap.clear(); 306 b.setInsertionPointAfter(hoistOverOp); 307 // We'll set an attribute to identify this op in a clone of this sub-tree. 308 ifOp->setAttr(idForIfOp, b.getBoolAttr(true)); 309 hoistOverOpClone = b.clone(*hoistOverOp, operandMap); 310 311 // Promote the 'then' block of the original affine.if in the then version. 312 promoteIfBlock(ifOp, /*elseBlock=*/false); 313 314 // Move the then version to the hoisted if op's 'then' block. 315 auto *thenBlock = hoistedIfOp.getThenBlock(); 316 thenBlock->getOperations().splice(thenBlock->begin(), 317 hoistOverOp->getBlock()->getOperations(), 318 Block::iterator(hoistOverOp)); 319 320 // Find the clone of the original affine.if op in the else version. 321 AffineIfOp ifCloneInElse; 322 hoistOverOpClone->walk([&](AffineIfOp ifClone) { 323 if (!ifClone->getAttr(idForIfOp)) 324 return WalkResult::advance(); 325 ifCloneInElse = ifClone; 326 return WalkResult::interrupt(); 327 }); 328 assert(ifCloneInElse && "if op clone should exist"); 329 // For the else block, promote the else block of the original 'if' if it had 330 // one; otherwise, the op itself is to be erased. 331 if (!ifCloneInElse.hasElse()) 332 ifCloneInElse.erase(); 333 else 334 promoteIfBlock(ifCloneInElse, /*elseBlock=*/true); 335 336 // Move the else version into the else block of the hoisted if op. 337 auto *elseBlock = hoistedIfOp.getElseBlock(); 338 elseBlock->getOperations().splice( 339 elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(), 340 Block::iterator(hoistOverOpClone)); 341 342 return hoistedIfOp; 343 } 344 345 LogicalResult 346 mlir::affineParallelize(AffineForOp forOp, 347 ArrayRef<LoopReduction> parallelReductions) { 348 // Fail early if there are iter arguments that are not reductions. 349 unsigned numReductions = parallelReductions.size(); 350 if (numReductions != forOp.getNumIterOperands()) 351 return failure(); 352 353 Location loc = forOp.getLoc(); 354 OpBuilder outsideBuilder(forOp); 355 AffineMap lowerBoundMap = forOp.getLowerBoundMap(); 356 ValueRange lowerBoundOperands = forOp.getLowerBoundOperands(); 357 AffineMap upperBoundMap = forOp.getUpperBoundMap(); 358 ValueRange upperBoundOperands = forOp.getUpperBoundOperands(); 359 360 // Creating empty 1-D affine.parallel op. 361 auto reducedValues = llvm::to_vector<4>(llvm::map_range( 362 parallelReductions, [](const LoopReduction &red) { return red.value; })); 363 auto reductionKinds = llvm::to_vector<4>(llvm::map_range( 364 parallelReductions, [](const LoopReduction &red) { return red.kind; })); 365 AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>( 366 loc, ValueRange(reducedValues).getTypes(), reductionKinds, 367 llvm::makeArrayRef(lowerBoundMap), lowerBoundOperands, 368 llvm::makeArrayRef(upperBoundMap), upperBoundOperands, 369 llvm::makeArrayRef(forOp.getStep())); 370 // Steal the body of the old affine for op. 371 newPloop.region().takeBody(forOp.region()); 372 Operation *yieldOp = &newPloop.getBody()->back(); 373 374 // Handle the initial values of reductions because the parallel loop always 375 // starts from the neutral value. 376 SmallVector<Value> newResults; 377 newResults.reserve(numReductions); 378 for (unsigned i = 0; i < numReductions; ++i) { 379 Value init = forOp.getIterOperands()[i]; 380 // This works because we are only handling single-op reductions at the 381 // moment. A switch on reduction kind or a mechanism to collect operations 382 // participating in the reduction will be necessary for multi-op reductions. 383 Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp(); 384 assert(reductionOp && "yielded value is expected to be produced by an op"); 385 outsideBuilder.getInsertionBlock()->getOperations().splice( 386 outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(), 387 reductionOp); 388 reductionOp->setOperands({init, newPloop->getResult(i)}); 389 forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0)); 390 } 391 392 // Update the loop terminator to yield reduced values bypassing the reduction 393 // operation itself (now moved outside of the loop) and erase the block 394 // arguments that correspond to reductions. Note that the loop always has one 395 // "main" induction variable whenc coming from a non-parallel for. 396 unsigned numIVs = 1; 397 yieldOp->setOperands(reducedValues); 398 newPloop.getBody()->eraseArguments( 399 llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs))); 400 401 forOp.erase(); 402 return success(); 403 } 404 405 // Returns success if any hoisting happened. 406 LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { 407 // Bail out early if the ifOp returns a result. TODO: Consider how to 408 // properly support this case. 409 if (ifOp.getNumResults() != 0) 410 return failure(); 411 412 // Apply canonicalization patterns and folding - this is necessary for the 413 // hoisting check to be correct (operands should be composed), and to be more 414 // effective (no unused operands). Since the pattern rewriter's folding is 415 // entangled with application of patterns, we may fold/end up erasing the op, 416 // in which case we return with `folded` being set. 417 RewritePatternSet patterns(ifOp.getContext()); 418 AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); 419 bool erased; 420 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 421 (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); 422 if (erased) { 423 if (folded) 424 *folded = true; 425 return failure(); 426 } 427 if (folded) 428 *folded = false; 429 430 // The folding above should have ensured this, but the affine.if's 431 // canonicalization is missing composition of affine.applys into it. 432 assert(llvm::all_of(ifOp.getOperands(), 433 [](Value v) { 434 return isTopLevelValue(v) || isForInductionVar(v); 435 }) && 436 "operands not composed"); 437 438 // We are going hoist as high as possible. 439 // TODO: this could be customized in the future. 440 auto *hoistOverOp = getOutermostInvariantForOp(ifOp); 441 442 AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp); 443 // Nothing to hoist over. 444 if (hoistedIfOp == ifOp) 445 return failure(); 446 447 // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up 448 // a sequence of affine.fors that are all perfectly nested). 449 (void)applyPatternsAndFoldGreedily( 450 hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(), 451 frozenPatterns); 452 453 return success(); 454 } 455 456 // Return the min expr after replacing the given dim. 457 AffineExpr mlir::substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min, 458 AffineExpr max, bool positivePath) { 459 if (e == dim) 460 return positivePath ? min : max; 461 if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) { 462 AffineExpr lhs = bin.getLHS(); 463 AffineExpr rhs = bin.getRHS(); 464 if (bin.getKind() == mlir::AffineExprKind::Add) 465 return substWithMin(lhs, dim, min, max, positivePath) + 466 substWithMin(rhs, dim, min, max, positivePath); 467 468 auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>(); 469 auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>(); 470 if (c1 && c1.getValue() < 0) 471 return getAffineBinaryOpExpr( 472 bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath)); 473 if (c2 && c2.getValue() < 0) 474 return getAffineBinaryOpExpr( 475 bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2); 476 return getAffineBinaryOpExpr( 477 bin.getKind(), substWithMin(lhs, dim, min, max, positivePath), 478 substWithMin(rhs, dim, min, max, positivePath)); 479 } 480 return e; 481 } 482 483 void mlir::normalizeAffineParallel(AffineParallelOp op) { 484 // Loops with min/max in bounds are not normalized at the moment. 485 if (op.hasMinMaxBounds()) 486 return; 487 488 AffineMap lbMap = op.lowerBoundsMap(); 489 SmallVector<int64_t, 8> steps = op.getSteps(); 490 // No need to do any work if the parallel op is already normalized. 491 bool isAlreadyNormalized = 492 llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) { 493 int64_t step = std::get<0>(tuple); 494 auto lbExpr = 495 std::get<1>(tuple).template dyn_cast<AffineConstantExpr>(); 496 return lbExpr && lbExpr.getValue() == 0 && step == 1; 497 }); 498 if (isAlreadyNormalized) 499 return; 500 501 AffineValueMap ranges; 502 AffineValueMap::difference(op.getUpperBoundsValueMap(), 503 op.getLowerBoundsValueMap(), &ranges); 504 auto builder = OpBuilder::atBlockBegin(op.getBody()); 505 auto zeroExpr = builder.getAffineConstantExpr(0); 506 SmallVector<AffineExpr, 8> lbExprs; 507 SmallVector<AffineExpr, 8> ubExprs; 508 for (unsigned i = 0, e = steps.size(); i < e; ++i) { 509 int64_t step = steps[i]; 510 511 // Adjust the lower bound to be 0. 512 lbExprs.push_back(zeroExpr); 513 514 // Adjust the upper bound expression: 'range / step'. 515 AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step); 516 ubExprs.push_back(ubExpr); 517 518 // Adjust the corresponding IV: 'lb + i * step'. 519 BlockArgument iv = op.getBody()->getArgument(i); 520 AffineExpr lbExpr = lbMap.getResult(i); 521 unsigned nDims = lbMap.getNumDims(); 522 auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step; 523 auto map = AffineMap::get(/*dimCount=*/nDims + 1, 524 /*symbolCount=*/lbMap.getNumSymbols(), expr); 525 526 // Use an 'affine.apply' op that will be simplified later in subsequent 527 // canonicalizations. 528 OperandRange lbOperands = op.getLowerBoundsOperands(); 529 OperandRange dimOperands = lbOperands.take_front(nDims); 530 OperandRange symbolOperands = lbOperands.drop_front(nDims); 531 SmallVector<Value, 8> applyOperands{dimOperands}; 532 applyOperands.push_back(iv); 533 applyOperands.append(symbolOperands.begin(), symbolOperands.end()); 534 auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands); 535 iv.replaceAllUsesExcept(apply, apply); 536 } 537 538 SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1); 539 op.setSteps(newSteps); 540 auto newLowerMap = AffineMap::get( 541 /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext()); 542 op.setLowerBounds({}, newLowerMap); 543 auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(), 544 ubExprs, op.getContext()); 545 op.setUpperBounds(ranges.getOperands(), newUpperMap); 546 } 547 548 /// Normalizes affine.for ops. If the affine.for op has only a single iteration 549 /// only then it is simply promoted, else it is normalized in the traditional 550 /// way, by converting the lower bound to zero and loop step to one. The upper 551 /// bound is set to the trip count of the loop. For now, original loops must 552 /// have lower bound with a single result only. There is no such restriction on 553 /// upper bounds. 554 LogicalResult mlir::normalizeAffineFor(AffineForOp op) { 555 if (succeeded(promoteIfSingleIteration(op))) 556 return success(); 557 558 // Check if the forop is already normalized. 559 if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) && 560 (op.getStep() == 1)) 561 return success(); 562 563 // Check if the lower bound has a single result only. Loops with a max lower 564 // bound can't be normalized without additional support like 565 // affine.execute_region's. If the lower bound does not have a single result 566 // then skip this op. 567 if (op.getLowerBoundMap().getNumResults() != 1) 568 return failure(); 569 570 Location loc = op.getLoc(); 571 OpBuilder opBuilder(op); 572 int64_t origLoopStep = op.getStep(); 573 574 // Calculate upperBound for normalized loop. 575 SmallVector<Value, 4> ubOperands; 576 AffineBound lb = op.getLowerBound(); 577 AffineBound ub = op.getUpperBound(); 578 ubOperands.reserve(ub.getNumOperands() + lb.getNumOperands()); 579 AffineMap origLbMap = lb.getMap(); 580 AffineMap origUbMap = ub.getMap(); 581 582 // Add dimension operands from upper/lower bound. 583 for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j) 584 ubOperands.push_back(ub.getOperand(j)); 585 for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j) 586 ubOperands.push_back(lb.getOperand(j)); 587 588 // Add symbol operands from upper/lower bound. 589 for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j) 590 ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j)); 591 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j) 592 ubOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j)); 593 594 // Add original result expressions from lower/upper bound map. 595 SmallVector<AffineExpr, 1> origLbExprs(origLbMap.getResults().begin(), 596 origLbMap.getResults().end()); 597 SmallVector<AffineExpr, 2> origUbExprs(origUbMap.getResults().begin(), 598 origUbMap.getResults().end()); 599 SmallVector<AffineExpr, 4> newUbExprs; 600 601 // The original upperBound can have more than one result. For the new 602 // upperBound of this loop, take difference of all possible combinations of 603 // the ub results and lb result and ceildiv with the loop step. For e.g., 604 // 605 // affine.for %i1 = 0 to min affine_map<(d0)[] -> (d0 + 32, 1024)>(%i0) 606 // will have an upperBound map as, 607 // affine_map<(d0)[] -> (((d0 + 32) - 0) ceildiv 1, (1024 - 0) ceildiv 608 // 1)>(%i0) 609 // 610 // Insert all combinations of upper/lower bound results. 611 for (unsigned i = 0, e = origUbExprs.size(); i < e; ++i) { 612 newUbExprs.push_back( 613 (origUbExprs[i] - origLbExprs[0]).ceilDiv(origLoopStep)); 614 } 615 616 // Construct newUbMap. 617 AffineMap newUbMap = 618 AffineMap::get(origLbMap.getNumDims() + origUbMap.getNumDims(), 619 origLbMap.getNumSymbols() + origUbMap.getNumSymbols(), 620 newUbExprs, opBuilder.getContext()); 621 canonicalizeMapAndOperands(&newUbMap, &ubOperands); 622 623 // Normalize the loop. 624 op.setUpperBound(ubOperands, newUbMap); 625 op.setLowerBound({}, opBuilder.getConstantAffineMap(0)); 626 op.setStep(1); 627 628 // Calculate the Value of new loopIV. Create affine.apply for the value of 629 // the loopIV in normalized loop. 630 opBuilder.setInsertionPointToStart(op.getBody()); 631 SmallVector<Value, 4> lbOperands(lb.getOperands().begin(), 632 lb.getOperands().begin() + 633 lb.getMap().getNumDims()); 634 // Add an extra dim operand for loopIV. 635 lbOperands.push_back(op.getInductionVar()); 636 // Add symbol operands from lower bound. 637 for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j) 638 lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j)); 639 640 AffineExpr origIVExpr = opBuilder.getAffineDimExpr(lb.getMap().getNumDims()); 641 AffineExpr newIVExpr = origIVExpr * origLoopStep + origLbMap.getResult(0); 642 AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1, 643 origLbMap.getNumSymbols(), newIVExpr); 644 canonicalizeMapAndOperands(&ivMap, &lbOperands); 645 Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands); 646 op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); 647 return success(); 648 } 649 650 /// Ensure that all operations that could be executed after `start` 651 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path 652 /// between the operations) do not have the potential memory effect 653 /// `EffectType` on `memOp`. `memOp` is an operation that reads or writes to 654 /// a memref. For example, if `EffectType` is MemoryEffects::Write, this method 655 /// will check if there is no write to the memory between `start` and `memOp` 656 /// that would change the read within `memOp`. 657 template <typename EffectType, typename T> 658 static bool hasNoInterveningEffect(Operation *start, T memOp) { 659 Value memref = memOp.getMemRef(); 660 bool isOriginalAllocation = memref.getDefiningOp<memref::AllocaOp>() || 661 memref.getDefiningOp<memref::AllocOp>(); 662 663 // A boolean representing whether an intervening operation could have impacted 664 // memOp. 665 bool hasSideEffect = false; 666 667 // Check whether the effect on memOp can be caused by a given operation op. 668 std::function<void(Operation *)> checkOperation = [&](Operation *op) { 669 // If the effect has alreay been found, early exit, 670 if (hasSideEffect) 671 return; 672 673 if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) { 674 SmallVector<MemoryEffects::EffectInstance, 1> effects; 675 memEffect.getEffects(effects); 676 677 bool opMayHaveEffect = false; 678 for (auto effect : effects) { 679 // If op causes EffectType on a potentially aliasing location for 680 // memOp, mark as having the effect. 681 if (isa<EffectType>(effect.getEffect())) { 682 if (isOriginalAllocation && effect.getValue() && 683 (effect.getValue().getDefiningOp<memref::AllocaOp>() || 684 effect.getValue().getDefiningOp<memref::AllocOp>())) { 685 if (effect.getValue() != memref) 686 continue; 687 } 688 opMayHaveEffect = true; 689 break; 690 } 691 } 692 693 if (!opMayHaveEffect) 694 return; 695 696 // If the side effect comes from an affine read or write, try to 697 // prove the side effecting `op` cannot reach `memOp`. 698 if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) { 699 MemRefAccess srcAccess(op); 700 MemRefAccess destAccess(memOp); 701 // Dependence analysis is only correct if both ops operate on the same 702 // memref. 703 if (srcAccess.memref == destAccess.memref) { 704 FlatAffineValueConstraints dependenceConstraints; 705 706 // Number of loops containing the start op and the ending operation. 707 unsigned minSurroundingLoops = 708 getNumCommonSurroundingLoops(*start, *memOp); 709 710 // Number of loops containing the operation `op` which has the 711 // potential memory side effect and can occur on a path between 712 // `start` and `memOp`. 713 unsigned nsLoops = getNumCommonSurroundingLoops(*op, *memOp); 714 715 // For ease, let's consider the case that `op` is a store and we're 716 // looking for other potential stores (e.g `op`) that overwrite memory 717 // after `start`, and before being read in `memOp`. In this case, we 718 // only need to consider other potential stores with depth > 719 // minSurrounding loops since `start` would overwrite any store with a 720 // smaller number of surrounding loops before. 721 unsigned d; 722 for (d = nsLoops + 1; d > minSurroundingLoops; d--) { 723 DependenceResult result = checkMemrefAccessDependence( 724 srcAccess, destAccess, d, &dependenceConstraints, 725 /*dependenceComponents=*/nullptr); 726 if (hasDependence(result)) { 727 hasSideEffect = true; 728 return; 729 } 730 } 731 732 // No side effect was seen, simply return. 733 return; 734 } 735 } 736 hasSideEffect = true; 737 return; 738 } 739 740 if (op->hasTrait<OpTrait::HasRecursiveSideEffects>()) { 741 // Recurse into the regions for this op and check whether the internal 742 // operations may have the side effect `EffectType` on memOp. 743 for (Region ®ion : op->getRegions()) 744 for (Block &block : region) 745 for (Operation &op : block) 746 checkOperation(&op); 747 return; 748 } 749 750 // Otherwise, conservatively assume generic operations have the effect 751 // on the operation 752 hasSideEffect = true; 753 }; 754 755 // Check all paths from ancestor op `parent` to the operation `to` for the 756 // effect. It is known that `to` must be contained within `parent`. 757 auto until = [&](Operation *parent, Operation *to) { 758 // TODO check only the paths from `parent` to `to`. 759 // Currently we fallback and check the entire parent op, rather than 760 // just the paths from the parent path, stopping after reaching `to`. 761 // This is conservatively correct, but could be made more aggressive. 762 assert(parent->isAncestor(to)); 763 checkOperation(parent); 764 }; 765 766 // Check for all paths from operation `from` to operation `untilOp` for the 767 // given memory effect. 768 std::function<void(Operation *, Operation *)> recur = 769 [&](Operation *from, Operation *untilOp) { 770 assert( 771 from->getParentRegion()->isAncestor(untilOp->getParentRegion()) && 772 "Checking for side effect between two operations without a common " 773 "ancestor"); 774 775 // If the operations are in different regions, recursively consider all 776 // path from `from` to the parent of `to` and all paths from the parent 777 // of `to` to `to`. 778 if (from->getParentRegion() != untilOp->getParentRegion()) { 779 recur(from, untilOp->getParentOp()); 780 until(untilOp->getParentOp(), untilOp); 781 return; 782 } 783 784 // Now, assuming that `from` and `to` exist in the same region, perform 785 // a CFG traversal to check all the relevant operations. 786 787 // Additional blocks to consider. 788 SmallVector<Block *, 2> todoBlocks; 789 { 790 // First consider the parent block of `from` an check all operations 791 // after `from`. 792 for (auto iter = ++from->getIterator(), end = from->getBlock()->end(); 793 iter != end && &*iter != untilOp; ++iter) { 794 checkOperation(&*iter); 795 } 796 797 // If the parent of `from` doesn't contain `to`, add the successors 798 // to the list of blocks to check. 799 if (untilOp->getBlock() != from->getBlock()) 800 for (Block *succ : from->getBlock()->getSuccessors()) 801 todoBlocks.push_back(succ); 802 } 803 804 SmallPtrSet<Block *, 4> done; 805 // Traverse the CFG until hitting `to`. 806 while (!todoBlocks.empty()) { 807 Block *blk = todoBlocks.pop_back_val(); 808 if (done.count(blk)) 809 continue; 810 done.insert(blk); 811 for (auto &op : *blk) { 812 if (&op == untilOp) 813 break; 814 checkOperation(&op); 815 if (&op == blk->getTerminator()) 816 for (Block *succ : blk->getSuccessors()) 817 todoBlocks.push_back(succ); 818 } 819 } 820 }; 821 recur(start, memOp); 822 return !hasSideEffect; 823 } 824 825 /// Attempt to eliminate loadOp by replacing it with a value stored into memory 826 /// which the load is guaranteed to retrieve. This check involves three 827 /// components: 1) The store and load must be on the same location 2) The store 828 /// must dominate (and therefore must always occur prior to) the load 3) No 829 /// other operations will overwrite the memory loaded between the given load 830 /// and store. If such a value exists, the replaced `loadOp` will be added to 831 /// `loadOpsToErase` and its memref will be added to `memrefsToErase`. 832 static LogicalResult forwardStoreToLoad( 833 AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase, 834 SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) { 835 836 // The store op candidate for forwarding that satisfies all conditions 837 // to replace the load, if any. 838 Operation *lastWriteStoreOp = nullptr; 839 840 for (auto *user : loadOp.getMemRef().getUsers()) { 841 auto storeOp = dyn_cast<AffineWriteOpInterface>(user); 842 if (!storeOp) 843 continue; 844 MemRefAccess srcAccess(storeOp); 845 MemRefAccess destAccess(loadOp); 846 847 // 1. Check if the store and the load have mathematically equivalent 848 // affine access functions; this implies that they statically refer to the 849 // same single memref element. As an example this filters out cases like: 850 // store %A[%i0 + 1] 851 // load %A[%i0] 852 // store %A[%M] 853 // load %A[%N] 854 // Use the AffineValueMap difference based memref access equality checking. 855 if (srcAccess != destAccess) 856 continue; 857 858 // 2. The store has to dominate the load op to be candidate. 859 if (!domInfo.dominates(storeOp, loadOp)) 860 continue; 861 862 // 3. Ensure there is no intermediate operation which could replace the 863 // value in memory. 864 if (!hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp)) 865 continue; 866 867 // We now have a candidate for forwarding. 868 assert(lastWriteStoreOp == nullptr && 869 "multiple simulataneous replacement stores"); 870 lastWriteStoreOp = storeOp; 871 } 872 873 if (!lastWriteStoreOp) 874 return failure(); 875 876 // Perform the actual store to load forwarding. 877 Value storeVal = 878 cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore(); 879 // Check if 2 values have the same shape. This is needed for affine vector 880 // loads and stores. 881 if (storeVal.getType() != loadOp.getValue().getType()) 882 return failure(); 883 loadOp.getValue().replaceAllUsesWith(storeVal); 884 // Record the memref for a later sweep to optimize away. 885 memrefsToErase.insert(loadOp.getMemRef()); 886 // Record this to erase later. 887 loadOpsToErase.push_back(loadOp); 888 return success(); 889 } 890 891 // This attempts to find stores which have no impact on the final result. 892 // A writing op writeA will be eliminated if there exists an op writeB if 893 // 1) writeA and writeB have mathematically equivalent affine access functions. 894 // 2) writeB postdominates writeA. 895 // 3) There is no potential read between writeA and writeB. 896 static void findUnusedStore(AffineWriteOpInterface writeA, 897 SmallVectorImpl<Operation *> &opsToErase, 898 PostDominanceInfo &postDominanceInfo) { 899 900 for (Operation *user : writeA.getMemRef().getUsers()) { 901 // Only consider writing operations. 902 auto writeB = dyn_cast<AffineWriteOpInterface>(user); 903 if (!writeB) 904 continue; 905 906 // The operations must be distinct. 907 if (writeB == writeA) 908 continue; 909 910 // Both operations must lie in the same region. 911 if (writeB->getParentRegion() != writeA->getParentRegion()) 912 continue; 913 914 // Both operations must write to the same memory. 915 MemRefAccess srcAccess(writeB); 916 MemRefAccess destAccess(writeA); 917 918 if (srcAccess != destAccess) 919 continue; 920 921 // writeB must postdominate writeA. 922 if (!postDominanceInfo.postDominates(writeB, writeA)) 923 continue; 924 925 // There cannot be an operation which reads from memory between 926 // the two writes. 927 if (!hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB)) 928 continue; 929 930 opsToErase.push_back(writeA); 931 break; 932 } 933 } 934 935 // The load to load forwarding / redundant load elimination is similar to the 936 // store to load forwarding. 937 // loadA will be be replaced with loadB if: 938 // 1) loadA and loadB have mathematically equivalent affine access functions. 939 // 2) loadB dominates loadA. 940 // 3) There is no write between loadA and loadB. 941 static void loadCSE(AffineReadOpInterface loadA, 942 SmallVectorImpl<Operation *> &loadOpsToErase, 943 DominanceInfo &domInfo) { 944 SmallVector<AffineReadOpInterface, 4> loadCandidates; 945 for (auto *user : loadA.getMemRef().getUsers()) { 946 auto loadB = dyn_cast<AffineReadOpInterface>(user); 947 if (!loadB || loadB == loadA) 948 continue; 949 950 MemRefAccess srcAccess(loadB); 951 MemRefAccess destAccess(loadA); 952 953 // 1. The accesses have to be to the same location. 954 if (srcAccess != destAccess) { 955 continue; 956 } 957 958 // 2. The store has to dominate the load op to be candidate. 959 if (!domInfo.dominates(loadB, loadA)) 960 continue; 961 962 // 3. There is no write between loadA and loadB. 963 if (!hasNoInterveningEffect<MemoryEffects::Write>(loadB.getOperation(), 964 loadA)) 965 continue; 966 967 // Check if two values have the same shape. This is needed for affine vector 968 // loads. 969 if (loadB.getValue().getType() != loadA.getValue().getType()) 970 continue; 971 972 loadCandidates.push_back(loadB); 973 } 974 975 // Of the legal load candidates, use the one that dominates all others 976 // to minimize the subsequent need to loadCSE 977 Value loadB; 978 for (AffineReadOpInterface option : loadCandidates) { 979 if (llvm::all_of(loadCandidates, [&](AffineReadOpInterface depStore) { 980 return depStore == option || 981 domInfo.dominates(option.getOperation(), 982 depStore.getOperation()); 983 })) { 984 loadB = option.getValue(); 985 break; 986 } 987 } 988 989 if (loadB) { 990 loadA.getValue().replaceAllUsesWith(loadB); 991 // Record this to erase later. 992 loadOpsToErase.push_back(loadA); 993 } 994 } 995 996 // The store to load forwarding and load CSE rely on three conditions: 997 // 998 // 1) store/load providing a replacement value and load being replaced need to 999 // have mathematically equivalent affine access functions (checked after full 1000 // composition of load/store operands); this implies that they access the same 1001 // single memref element for all iterations of the common surrounding loop, 1002 // 1003 // 2) the store/load op should dominate the load op, 1004 // 1005 // 3) no operation that may write to memory read by the load being replaced can 1006 // occur after executing the instruction (load or store) providing the 1007 // replacement value and before the load being replaced (thus potentially 1008 // allowing overwriting the memory read by the load). 1009 // 1010 // The above conditions are simple to check, sufficient, and powerful for most 1011 // cases in practice - they are sufficient, but not necessary --- since they 1012 // don't reason about loops that are guaranteed to execute at least once or 1013 // multiple sources to forward from. 1014 // 1015 // TODO: more forwarding can be done when support for 1016 // loop/conditional live-out SSA values is available. 1017 // TODO: do general dead store elimination for memref's. This pass 1018 // currently only eliminates the stores only if no other loads/uses (other 1019 // than dealloc) remain. 1020 // 1021 void mlir::affineScalarReplace(FuncOp f, DominanceInfo &domInfo, 1022 PostDominanceInfo &postDomInfo) { 1023 // Load op's whose results were replaced by those forwarded from stores. 1024 SmallVector<Operation *, 8> opsToErase; 1025 1026 // A list of memref's that are potentially dead / could be eliminated. 1027 SmallPtrSet<Value, 4> memrefsToErase; 1028 1029 // Walk all load's and perform store to load forwarding. 1030 f.walk([&](AffineReadOpInterface loadOp) { 1031 if (failed( 1032 forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) { 1033 loadCSE(loadOp, opsToErase, domInfo); 1034 } 1035 }); 1036 1037 // Erase all load op's whose results were replaced with store fwd'ed ones. 1038 for (auto *op : opsToErase) 1039 op->erase(); 1040 opsToErase.clear(); 1041 1042 // Walk all store's and perform unused store elimination 1043 f.walk([&](AffineWriteOpInterface storeOp) { 1044 findUnusedStore(storeOp, opsToErase, postDomInfo); 1045 }); 1046 // Erase all store op's which don't impact the program 1047 for (auto *op : opsToErase) 1048 op->erase(); 1049 1050 // Check if the store fwd'ed memrefs are now left with only stores and can 1051 // thus be completely deleted. Note: the canonicalize pass should be able 1052 // to do this as well, but we'll do it here since we collected these anyway. 1053 for (auto memref : memrefsToErase) { 1054 // If the memref hasn't been alloc'ed in this function, skip. 1055 Operation *defOp = memref.getDefiningOp(); 1056 if (!defOp || !isa<memref::AllocOp>(defOp)) 1057 // TODO: if the memref was returned by a 'call' operation, we 1058 // could still erase it if the call had no side-effects. 1059 continue; 1060 if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) { 1061 return !isa<AffineWriteOpInterface, memref::DeallocOp>(ownerOp); 1062 })) 1063 continue; 1064 1065 // Erase all stores, the dealloc, and the alloc on the memref. 1066 for (auto *user : llvm::make_early_inc_range(memref.getUsers())) 1067 user->erase(); 1068 defOp->erase(); 1069 } 1070 } 1071 1072 // Perform the replacement in `op`. 1073 LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, 1074 Operation *op, 1075 ArrayRef<Value> extraIndices, 1076 AffineMap indexRemap, 1077 ArrayRef<Value> extraOperands, 1078 ArrayRef<Value> symbolOperands, 1079 bool allowNonDereferencingOps) { 1080 unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank(); 1081 (void)newMemRefRank; // unused in opt mode 1082 unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank(); 1083 (void)oldMemRefRank; // unused in opt mode 1084 if (indexRemap) { 1085 assert(indexRemap.getNumSymbols() == symbolOperands.size() && 1086 "symbolic operand count mismatch"); 1087 assert(indexRemap.getNumInputs() == 1088 extraOperands.size() + oldMemRefRank + symbolOperands.size()); 1089 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); 1090 } else { 1091 assert(oldMemRefRank + extraIndices.size() == newMemRefRank); 1092 } 1093 1094 // Assert same elemental type. 1095 assert(oldMemRef.getType().cast<MemRefType>().getElementType() == 1096 newMemRef.getType().cast<MemRefType>().getElementType()); 1097 1098 SmallVector<unsigned, 2> usePositions; 1099 for (const auto &opEntry : llvm::enumerate(op->getOperands())) { 1100 if (opEntry.value() == oldMemRef) 1101 usePositions.push_back(opEntry.index()); 1102 } 1103 1104 // If memref doesn't appear, nothing to do. 1105 if (usePositions.empty()) 1106 return success(); 1107 1108 if (usePositions.size() > 1) { 1109 // TODO: extend it for this case when needed (rare). 1110 assert(false && "multiple dereferencing uses in a single op not supported"); 1111 return failure(); 1112 } 1113 1114 unsigned memRefOperandPos = usePositions.front(); 1115 1116 OpBuilder builder(op); 1117 // The following checks if op is dereferencing memref and performs the access 1118 // index rewrites. 1119 auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op); 1120 if (!affMapAccInterface) { 1121 if (!allowNonDereferencingOps) { 1122 // Failure: memref used in a non-dereferencing context (potentially 1123 // escapes); no replacement in these cases unless allowNonDereferencingOps 1124 // is set. 1125 return failure(); 1126 } 1127 op->setOperand(memRefOperandPos, newMemRef); 1128 return success(); 1129 } 1130 // Perform index rewrites for the dereferencing op and then replace the op 1131 NamedAttribute oldMapAttrPair = 1132 affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); 1133 AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue(); 1134 unsigned oldMapNumInputs = oldMap.getNumInputs(); 1135 SmallVector<Value, 4> oldMapOperands( 1136 op->operand_begin() + memRefOperandPos + 1, 1137 op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); 1138 1139 // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. 1140 SmallVector<Value, 4> oldMemRefOperands; 1141 SmallVector<Value, 4> affineApplyOps; 1142 oldMemRefOperands.reserve(oldMemRefRank); 1143 if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { 1144 for (auto resultExpr : oldMap.getResults()) { 1145 auto singleResMap = AffineMap::get(oldMap.getNumDims(), 1146 oldMap.getNumSymbols(), resultExpr); 1147 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, 1148 oldMapOperands); 1149 oldMemRefOperands.push_back(afOp); 1150 affineApplyOps.push_back(afOp); 1151 } 1152 } else { 1153 oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); 1154 } 1155 1156 // Construct new indices as a remap of the old ones if a remapping has been 1157 // provided. The indices of a memref come right after it, i.e., 1158 // at position memRefOperandPos + 1. 1159 SmallVector<Value, 4> remapOperands; 1160 remapOperands.reserve(extraOperands.size() + oldMemRefRank + 1161 symbolOperands.size()); 1162 remapOperands.append(extraOperands.begin(), extraOperands.end()); 1163 remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); 1164 remapOperands.append(symbolOperands.begin(), symbolOperands.end()); 1165 1166 SmallVector<Value, 4> remapOutputs; 1167 remapOutputs.reserve(oldMemRefRank); 1168 1169 if (indexRemap && 1170 indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { 1171 // Remapped indices. 1172 for (auto resultExpr : indexRemap.getResults()) { 1173 auto singleResMap = AffineMap::get( 1174 indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); 1175 auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, 1176 remapOperands); 1177 remapOutputs.push_back(afOp); 1178 affineApplyOps.push_back(afOp); 1179 } 1180 } else { 1181 // No remapping specified. 1182 remapOutputs.assign(remapOperands.begin(), remapOperands.end()); 1183 } 1184 1185 SmallVector<Value, 4> newMapOperands; 1186 newMapOperands.reserve(newMemRefRank); 1187 1188 // Prepend 'extraIndices' in 'newMapOperands'. 1189 for (Value extraIndex : extraIndices) { 1190 assert(extraIndex.getDefiningOp()->getNumResults() == 1 && 1191 "single result op's expected to generate these indices"); 1192 assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && 1193 "invalid memory op index"); 1194 newMapOperands.push_back(extraIndex); 1195 } 1196 1197 // Append 'remapOutputs' to 'newMapOperands'. 1198 newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); 1199 1200 // Create new fully composed AffineMap for new op to be created. 1201 assert(newMapOperands.size() == newMemRefRank); 1202 auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); 1203 // TODO: Avoid creating/deleting temporary AffineApplyOps here. 1204 fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); 1205 newMap = simplifyAffineMap(newMap); 1206 canonicalizeMapAndOperands(&newMap, &newMapOperands); 1207 // Remove any affine.apply's that became dead as a result of composition. 1208 for (Value value : affineApplyOps) 1209 if (value.use_empty()) 1210 value.getDefiningOp()->erase(); 1211 1212 OperationState state(op->getLoc(), op->getName()); 1213 // Construct the new operation using this memref. 1214 state.operands.reserve(op->getNumOperands() + extraIndices.size()); 1215 // Insert the non-memref operands. 1216 state.operands.append(op->operand_begin(), 1217 op->operand_begin() + memRefOperandPos); 1218 // Insert the new memref value. 1219 state.operands.push_back(newMemRef); 1220 1221 // Insert the new memref map operands. 1222 state.operands.append(newMapOperands.begin(), newMapOperands.end()); 1223 1224 // Insert the remaining operands unmodified. 1225 state.operands.append(op->operand_begin() + memRefOperandPos + 1 + 1226 oldMapNumInputs, 1227 op->operand_end()); 1228 1229 // Result types don't change. Both memref's are of the same elemental type. 1230 state.types.reserve(op->getNumResults()); 1231 for (auto result : op->getResults()) 1232 state.types.push_back(result.getType()); 1233 1234 // Add attribute for 'newMap', other Attributes do not change. 1235 auto newMapAttr = AffineMapAttr::get(newMap); 1236 for (auto namedAttr : op->getAttrs()) { 1237 if (namedAttr.getName() == oldMapAttrPair.getName()) 1238 state.attributes.push_back({namedAttr.getName(), newMapAttr}); 1239 else 1240 state.attributes.push_back(namedAttr); 1241 } 1242 1243 // Create the new operation. 1244 auto *repOp = builder.createOperation(state); 1245 op->replaceAllUsesWith(repOp); 1246 op->erase(); 1247 1248 return success(); 1249 } 1250 1251 LogicalResult mlir::replaceAllMemRefUsesWith( 1252 Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices, 1253 AffineMap indexRemap, ArrayRef<Value> extraOperands, 1254 ArrayRef<Value> symbolOperands, Operation *domOpFilter, 1255 Operation *postDomOpFilter, bool allowNonDereferencingOps, 1256 bool replaceInDeallocOp) { 1257 unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank(); 1258 (void)newMemRefRank; // unused in opt mode 1259 unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank(); 1260 (void)oldMemRefRank; 1261 if (indexRemap) { 1262 assert(indexRemap.getNumSymbols() == symbolOperands.size() && 1263 "symbol operand count mismatch"); 1264 assert(indexRemap.getNumInputs() == 1265 extraOperands.size() + oldMemRefRank + symbolOperands.size()); 1266 assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); 1267 } else { 1268 assert(oldMemRefRank + extraIndices.size() == newMemRefRank); 1269 } 1270 1271 // Assert same elemental type. 1272 assert(oldMemRef.getType().cast<MemRefType>().getElementType() == 1273 newMemRef.getType().cast<MemRefType>().getElementType()); 1274 1275 std::unique_ptr<DominanceInfo> domInfo; 1276 std::unique_ptr<PostDominanceInfo> postDomInfo; 1277 if (domOpFilter) 1278 domInfo = 1279 std::make_unique<DominanceInfo>(domOpFilter->getParentOfType<FuncOp>()); 1280 1281 if (postDomOpFilter) 1282 postDomInfo = std::make_unique<PostDominanceInfo>( 1283 postDomOpFilter->getParentOfType<FuncOp>()); 1284 1285 // Walk all uses of old memref; collect ops to perform replacement. We use a 1286 // DenseSet since an operation could potentially have multiple uses of a 1287 // memref (although rare), and the replacement later is going to erase ops. 1288 DenseSet<Operation *> opsToReplace; 1289 for (auto *op : oldMemRef.getUsers()) { 1290 // Skip this use if it's not dominated by domOpFilter. 1291 if (domOpFilter && !domInfo->dominates(domOpFilter, op)) 1292 continue; 1293 1294 // Skip this use if it's not post-dominated by postDomOpFilter. 1295 if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op)) 1296 continue; 1297 1298 // Skip dealloc's - no replacement is necessary, and a memref replacement 1299 // at other uses doesn't hurt these dealloc's. 1300 if (isa<memref::DeallocOp>(op) && !replaceInDeallocOp) 1301 continue; 1302 1303 // Check if the memref was used in a non-dereferencing context. It is fine 1304 // for the memref to be used in a non-dereferencing way outside of the 1305 // region where this replacement is happening. 1306 if (!isa<AffineMapAccessInterface>(*op)) { 1307 if (!allowNonDereferencingOps) { 1308 LLVM_DEBUG(llvm::dbgs() 1309 << "Memref replacement failed: non-deferencing memref op: \n" 1310 << *op << '\n'); 1311 return failure(); 1312 } 1313 // Non-dereferencing ops with the MemRefsNormalizable trait are 1314 // supported for replacement. 1315 if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) { 1316 LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a " 1317 "memrefs normalizable trait: \n" 1318 << *op << '\n'); 1319 return failure(); 1320 } 1321 } 1322 1323 // We'll first collect and then replace --- since replacement erases the op 1324 // that has the use, and that op could be postDomFilter or domFilter itself! 1325 opsToReplace.insert(op); 1326 } 1327 1328 for (auto *op : opsToReplace) { 1329 if (failed(replaceAllMemRefUsesWith( 1330 oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands, 1331 symbolOperands, allowNonDereferencingOps))) 1332 llvm_unreachable("memref replacement guaranteed to succeed here"); 1333 } 1334 1335 return success(); 1336 } 1337 1338 /// Given an operation, inserts one or more single result affine 1339 /// apply operations, results of which are exclusively used by this operation 1340 /// operation. The operands of these newly created affine apply ops are 1341 /// guaranteed to be loop iterators or terminal symbols of a function. 1342 /// 1343 /// Before 1344 /// 1345 /// affine.for %i = 0 to #map(%N) 1346 /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) 1347 /// "send"(%idx, %A, ...) 1348 /// "compute"(%idx) 1349 /// 1350 /// After 1351 /// 1352 /// affine.for %i = 0 to #map(%N) 1353 /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) 1354 /// "send"(%idx, %A, ...) 1355 /// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) 1356 /// "compute"(%idx_) 1357 /// 1358 /// This allows applying different transformations on send and compute (for eg. 1359 /// different shifts/delays). 1360 /// 1361 /// Returns nullptr either if none of opInst's operands were the result of an 1362 /// affine.apply and thus there was no affine computation slice to create, or if 1363 /// all the affine.apply op's supplying operands to this opInst did not have any 1364 /// uses besides this opInst; otherwise returns the list of affine.apply 1365 /// operations created in output argument `sliceOps`. 1366 void mlir::createAffineComputationSlice( 1367 Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) { 1368 // Collect all operands that are results of affine apply ops. 1369 SmallVector<Value, 4> subOperands; 1370 subOperands.reserve(opInst->getNumOperands()); 1371 for (auto operand : opInst->getOperands()) 1372 if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp())) 1373 subOperands.push_back(operand); 1374 1375 // Gather sequence of AffineApplyOps reachable from 'subOperands'. 1376 SmallVector<Operation *, 4> affineApplyOps; 1377 getReachableAffineApplyOps(subOperands, affineApplyOps); 1378 // Skip transforming if there are no affine maps to compose. 1379 if (affineApplyOps.empty()) 1380 return; 1381 1382 // Check if all uses of the affine apply op's lie only in this op op, in 1383 // which case there would be nothing to do. 1384 bool localized = true; 1385 for (auto *op : affineApplyOps) { 1386 for (auto result : op->getResults()) { 1387 for (auto *user : result.getUsers()) { 1388 if (user != opInst) { 1389 localized = false; 1390 break; 1391 } 1392 } 1393 } 1394 } 1395 if (localized) 1396 return; 1397 1398 OpBuilder builder(opInst); 1399 SmallVector<Value, 4> composedOpOperands(subOperands); 1400 auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); 1401 fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); 1402 1403 // Create an affine.apply for each of the map results. 1404 sliceOps->reserve(composedMap.getNumResults()); 1405 for (auto resultExpr : composedMap.getResults()) { 1406 auto singleResMap = AffineMap::get(composedMap.getNumDims(), 1407 composedMap.getNumSymbols(), resultExpr); 1408 sliceOps->push_back(builder.create<AffineApplyOp>( 1409 opInst->getLoc(), singleResMap, composedOpOperands)); 1410 } 1411 1412 // Construct the new operands that include the results from the composed 1413 // affine apply op above instead of existing ones (subOperands). So, they 1414 // differ from opInst's operands only for those operands in 'subOperands', for 1415 // which they will be replaced by the corresponding one from 'sliceOps'. 1416 SmallVector<Value, 4> newOperands(opInst->getOperands()); 1417 for (unsigned i = 0, e = newOperands.size(); i < e; i++) { 1418 // Replace the subOperands from among the new operands. 1419 unsigned j, f; 1420 for (j = 0, f = subOperands.size(); j < f; j++) { 1421 if (newOperands[i] == subOperands[j]) 1422 break; 1423 } 1424 if (j < subOperands.size()) { 1425 newOperands[i] = (*sliceOps)[j]; 1426 } 1427 } 1428 for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { 1429 opInst->setOperand(idx, newOperands[idx]); 1430 } 1431 } 1432 1433 /// Enum to set patterns of affine expr in tiled-layout map. 1434 /// TileFloorDiv: <dim expr> div <tile size> 1435 /// TileMod: <dim expr> mod <tile size> 1436 /// TileNone: None of the above 1437 /// Example: 1438 /// #tiled_2d_128x256 = affine_map<(d0, d1) 1439 /// -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)> 1440 /// "d0 div 128" and "d1 div 256" ==> TileFloorDiv 1441 /// "d0 mod 128" and "d1 mod 256" ==> TileMod 1442 enum TileExprPattern { TileFloorDiv, TileMod, TileNone }; 1443 1444 /// Check if `map` is a tiled layout. In the tiled layout, specific k dimensions 1445 /// being floordiv'ed by respective tile sizes appeare in a mod with the same 1446 /// tile sizes, and no other expression involves those k dimensions. This 1447 /// function stores a vector of tuples (`tileSizePos`) including AffineExpr for 1448 /// tile size, positions of corresponding `floordiv` and `mod`. If it is not a 1449 /// tiled layout, an empty vector is returned. 1450 static LogicalResult getTileSizePos( 1451 AffineMap map, 1452 SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) { 1453 // Create `floordivExprs` which is a vector of tuples including LHS and RHS of 1454 // `floordiv` and its position in `map` output. 1455 // Example: #tiled_2d_128x256 = affine_map<(d0, d1) 1456 // -> (d0 div 128, d1 div 256, d0 mod 128, d1 mod 256)> 1457 // In this example, `floordivExprs` includes {d0, 128, 0} and {d1, 256, 1}. 1458 SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs; 1459 unsigned pos = 0; 1460 for (AffineExpr expr : map.getResults()) { 1461 if (expr.getKind() == AffineExprKind::FloorDiv) { 1462 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 1463 if (binaryExpr.getRHS().isa<AffineConstantExpr>()) 1464 floordivExprs.emplace_back( 1465 std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos)); 1466 } 1467 pos++; 1468 } 1469 // Not tiled layout if `floordivExprs` is empty. 1470 if (floordivExprs.empty()) { 1471 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{}; 1472 return success(); 1473 } 1474 1475 // Check if LHS of `floordiv` is used in LHS of `mod`. If not used, `map` is 1476 // not tiled layout. 1477 for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) { 1478 AffineExpr floordivExprLHS = std::get<0>(fexpr); 1479 AffineExpr floordivExprRHS = std::get<1>(fexpr); 1480 unsigned floordivPos = std::get<2>(fexpr); 1481 1482 // Walk affinexpr of `map` output except `fexpr`, and check if LHS and RHS 1483 // of `fexpr` are used in LHS and RHS of `mod`. If LHS of `fexpr` is used 1484 // other expr, the map is not tiled layout. Example of non tiled layout: 1485 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 floordiv 256)> 1486 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 128)> 1487 // affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 mod 256, d2 mod 1488 // 256)> 1489 bool found = false; 1490 pos = 0; 1491 for (AffineExpr expr : map.getResults()) { 1492 bool notTiled = false; 1493 if (pos != floordivPos) { 1494 expr.walk([&](AffineExpr e) { 1495 if (e == floordivExprLHS) { 1496 if (expr.getKind() == AffineExprKind::Mod) { 1497 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>(); 1498 // If LHS and RHS of `mod` are the same with those of floordiv. 1499 if (floordivExprLHS == binaryExpr.getLHS() && 1500 floordivExprRHS == binaryExpr.getRHS()) { 1501 // Save tile size (RHS of `mod`), and position of `floordiv` and 1502 // `mod` if same expr with `mod` is not found yet. 1503 if (!found) { 1504 tileSizePos.emplace_back( 1505 std::make_tuple(binaryExpr.getRHS(), floordivPos, pos)); 1506 found = true; 1507 } else { 1508 // Non tiled layout: Have multilpe `mod` with the same LHS. 1509 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1510 // mod 256, d2 mod 256)> 1511 notTiled = true; 1512 } 1513 } else { 1514 // Non tiled layout: RHS of `mod` is different from `floordiv`. 1515 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1516 // mod 128)> 1517 notTiled = true; 1518 } 1519 } else { 1520 // Non tiled layout: LHS is the same, but not `mod`. 1521 // eg. affine_map<(d0, d1, d2) -> (d0, d1, d2 floordiv 256, d2 1522 // floordiv 256)> 1523 notTiled = true; 1524 } 1525 } 1526 }); 1527 } 1528 if (notTiled) { 1529 tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{}; 1530 return success(); 1531 } 1532 pos++; 1533 } 1534 } 1535 return success(); 1536 } 1537 1538 /// Check if `dim` dimension of memrefType with `layoutMap` becomes dynamic 1539 /// after normalization. Dimensions that include dynamic dimensions in the map 1540 /// output will become dynamic dimensions. Return true if `dim` is dynamic 1541 /// dimension. 1542 /// 1543 /// Example: 1544 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)> 1545 /// 1546 /// If d1 is dynamic dimension, 2nd and 3rd dimension of map output are dynamic. 1547 /// memref<4x?xf32, #map0> ==> memref<4x?x?xf32> 1548 static bool 1549 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap, 1550 SmallVectorImpl<unsigned> &inMemrefTypeDynDims, 1551 MLIRContext *context) { 1552 bool isDynamicDim = false; 1553 AffineExpr expr = layoutMap.getResults()[dim]; 1554 // Check if affine expr of the dimension includes dynamic dimension of input 1555 // memrefType. 1556 expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) { 1557 if (e.isa<AffineDimExpr>()) { 1558 for (unsigned dm : inMemrefTypeDynDims) { 1559 if (e == getAffineDimExpr(dm, context)) { 1560 isDynamicDim = true; 1561 } 1562 } 1563 } 1564 }); 1565 return isDynamicDim; 1566 } 1567 1568 /// Create affine expr to calculate dimension size for a tiled-layout map. 1569 static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput, 1570 TileExprPattern pat) { 1571 // Create map output for the patterns. 1572 // "floordiv <tile size>" ==> "ceildiv <tile size>" 1573 // "mod <tile size>" ==> "<tile size>" 1574 AffineExpr newMapOutput; 1575 AffineBinaryOpExpr binaryExpr = nullptr; 1576 switch (pat) { 1577 case TileExprPattern::TileMod: 1578 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>(); 1579 newMapOutput = binaryExpr.getRHS(); 1580 break; 1581 case TileExprPattern::TileFloorDiv: 1582 binaryExpr = oldMapOutput.cast<AffineBinaryOpExpr>(); 1583 newMapOutput = getAffineBinaryOpExpr( 1584 AffineExprKind::CeilDiv, binaryExpr.getLHS(), binaryExpr.getRHS()); 1585 break; 1586 default: 1587 newMapOutput = oldMapOutput; 1588 } 1589 return newMapOutput; 1590 } 1591 1592 /// Create new maps to calculate each dimension size of `newMemRefType`, and 1593 /// create `newDynamicSizes` from them by using AffineApplyOp. 1594 /// 1595 /// Steps for normalizing dynamic memrefs for a tiled layout map 1596 /// Example: 1597 /// #map0 = affine_map<(d0, d1) -> (d0, d1 floordiv 32, d1 mod 32)> 1598 /// %0 = dim %arg0, %c1 :memref<4x?xf32> 1599 /// %1 = alloc(%0) : memref<4x?xf32, #map0> 1600 /// 1601 /// (Before this function) 1602 /// 1. Check if `map`(#map0) is a tiled layout using `getTileSizePos()`. Only 1603 /// single layout map is supported. 1604 /// 1605 /// 2. Create normalized memrefType using `isNormalizedMemRefDynamicDim()`. It 1606 /// is memref<4x?x?xf32> in the above example. 1607 /// 1608 /// (In this function) 1609 /// 3. Create new maps to calculate each dimension of the normalized memrefType 1610 /// using `createDimSizeExprForTiledLayout()`. In the tiled layout, the 1611 /// dimension size can be calculated by replacing "floordiv <tile size>" with 1612 /// "ceildiv <tile size>" and "mod <tile size>" with "<tile size>". 1613 /// - New map in the above example 1614 /// #map0 = affine_map<(d0, d1) -> (d0)> 1615 /// #map1 = affine_map<(d0, d1) -> (d1 ceildiv 32)> 1616 /// #map2 = affine_map<(d0, d1) -> (32)> 1617 /// 1618 /// 4. Create AffineApplyOp to apply the new maps. The output of AffineApplyOp 1619 /// is used in dynamicSizes of new AllocOp. 1620 /// %0 = dim %arg0, %c1 : memref<4x?xf32> 1621 /// %c4 = arith.constant 4 : index 1622 /// %1 = affine.apply #map1(%c4, %0) 1623 /// %2 = affine.apply #map2(%c4, %0) 1624 static void createNewDynamicSizes(MemRefType oldMemRefType, 1625 MemRefType newMemRefType, AffineMap map, 1626 memref::AllocOp *allocOp, OpBuilder b, 1627 SmallVectorImpl<Value> &newDynamicSizes) { 1628 // Create new input for AffineApplyOp. 1629 SmallVector<Value, 4> inAffineApply; 1630 ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape(); 1631 unsigned dynIdx = 0; 1632 for (unsigned d = 0; d < oldMemRefType.getRank(); ++d) { 1633 if (oldMemRefShape[d] < 0) { 1634 // Use dynamicSizes of allocOp for dynamic dimension. 1635 inAffineApply.emplace_back(allocOp->dynamicSizes()[dynIdx]); 1636 dynIdx++; 1637 } else { 1638 // Create ConstantOp for static dimension. 1639 Attribute constantAttr = 1640 b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); 1641 inAffineApply.emplace_back( 1642 b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr)); 1643 } 1644 } 1645 1646 // Create new map to calculate each dimension size of new memref for each 1647 // original map output. Only for dynamic dimesion of `newMemRefType`. 1648 unsigned newDimIdx = 0; 1649 ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape(); 1650 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1651 (void)getTileSizePos(map, tileSizePos); 1652 for (AffineExpr expr : map.getResults()) { 1653 if (newMemRefShape[newDimIdx] < 0) { 1654 // Create new maps to calculate each dimension size of new memref. 1655 enum TileExprPattern pat = TileExprPattern::TileNone; 1656 for (auto pos : tileSizePos) { 1657 if (newDimIdx == std::get<1>(pos)) 1658 pat = TileExprPattern::TileFloorDiv; 1659 else if (newDimIdx == std::get<2>(pos)) 1660 pat = TileExprPattern::TileMod; 1661 } 1662 AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat); 1663 AffineMap newMap = 1664 AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput); 1665 Value affineApp = 1666 b.create<AffineApplyOp>(allocOp->getLoc(), newMap, inAffineApply); 1667 newDynamicSizes.emplace_back(affineApp); 1668 } 1669 newDimIdx++; 1670 } 1671 } 1672 1673 // TODO: Currently works for static memrefs with a single layout map. 1674 LogicalResult mlir::normalizeMemRef(memref::AllocOp *allocOp) { 1675 MemRefType memrefType = allocOp->getType(); 1676 OpBuilder b(*allocOp); 1677 1678 // Fetch a new memref type after normalizing the old memref to have an 1679 // identity map layout. 1680 MemRefType newMemRefType = 1681 normalizeMemRefType(memrefType, b, allocOp->symbolOperands().size()); 1682 if (newMemRefType == memrefType) 1683 // Either memrefType already had an identity map or the map couldn't be 1684 // transformed to an identity map. 1685 return failure(); 1686 1687 Value oldMemRef = allocOp->getResult(); 1688 1689 SmallVector<Value, 4> symbolOperands(allocOp->symbolOperands()); 1690 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 1691 memref::AllocOp newAlloc; 1692 // Check if `layoutMap` is a tiled layout. Only single layout map is 1693 // supported for normalizing dynamic memrefs. 1694 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1695 (void)getTileSizePos(layoutMap, tileSizePos); 1696 if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) { 1697 MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>(); 1698 SmallVector<Value, 4> newDynamicSizes; 1699 createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, 1700 newDynamicSizes); 1701 // Add the new dynamic sizes in new AllocOp. 1702 newAlloc = 1703 b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType, 1704 newDynamicSizes, allocOp->alignmentAttr()); 1705 } else { 1706 newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType, 1707 allocOp->alignmentAttr()); 1708 } 1709 // Replace all uses of the old memref. 1710 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, 1711 /*extraIndices=*/{}, 1712 /*indexRemap=*/layoutMap, 1713 /*extraOperands=*/{}, 1714 /*symbolOperands=*/symbolOperands, 1715 /*domOpFilter=*/nullptr, 1716 /*postDomOpFilter=*/nullptr, 1717 /*allowNonDereferencingOps=*/true))) { 1718 // If it failed (due to escapes for example), bail out. 1719 newAlloc.erase(); 1720 return failure(); 1721 } 1722 // Replace any uses of the original alloc op and erase it. All remaining uses 1723 // have to be dealloc's; RAMUW above would've failed otherwise. 1724 assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) { 1725 return isa<memref::DeallocOp>(op); 1726 })); 1727 oldMemRef.replaceAllUsesWith(newAlloc); 1728 allocOp->erase(); 1729 return success(); 1730 } 1731 1732 MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, 1733 unsigned numSymbolicOperands) { 1734 unsigned rank = memrefType.getRank(); 1735 if (rank == 0) 1736 return memrefType; 1737 1738 if (memrefType.getLayout().isIdentity()) { 1739 // Either no maps is associated with this memref or this memref has 1740 // a trivial (identity) map. 1741 return memrefType; 1742 } 1743 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 1744 1745 // We don't do any checks for one-to-one'ness; we assume that it is 1746 // one-to-one. 1747 1748 // Normalize only static memrefs and dynamic memrefs with a tiled-layout map 1749 // for now. 1750 // TODO: Normalize the other types of dynamic memrefs. 1751 SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos; 1752 (void)getTileSizePos(layoutMap, tileSizePos); 1753 if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty()) 1754 return memrefType; 1755 1756 // We have a single map that is not an identity map. Create a new memref 1757 // with the right shape and an identity layout map. 1758 ArrayRef<int64_t> shape = memrefType.getShape(); 1759 // FlatAffineConstraint may later on use symbolicOperands. 1760 FlatAffineConstraints fac(rank, numSymbolicOperands); 1761 SmallVector<unsigned, 4> memrefTypeDynDims; 1762 for (unsigned d = 0; d < rank; ++d) { 1763 // Use constraint system only in static dimensions. 1764 if (shape[d] > 0) { 1765 fac.addBound(FlatAffineConstraints::LB, d, 0); 1766 fac.addBound(FlatAffineConstraints::UB, d, shape[d] - 1); 1767 } else { 1768 memrefTypeDynDims.emplace_back(d); 1769 } 1770 } 1771 // We compose this map with the original index (logical) space to derive 1772 // the upper bounds for the new index space. 1773 unsigned newRank = layoutMap.getNumResults(); 1774 if (failed(fac.composeMatchingMap(layoutMap))) 1775 return memrefType; 1776 // TODO: Handle semi-affine maps. 1777 // Project out the old data dimensions. 1778 fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds()); 1779 SmallVector<int64_t, 4> newShape(newRank); 1780 for (unsigned d = 0; d < newRank; ++d) { 1781 // Check if each dimension of normalized memrefType is dynamic. 1782 bool isDynDim = isNormalizedMemRefDynamicDim( 1783 d, layoutMap, memrefTypeDynDims, b.getContext()); 1784 if (isDynDim) { 1785 newShape[d] = -1; 1786 } else { 1787 // The lower bound for the shape is always zero. 1788 auto ubConst = fac.getConstantBound(FlatAffineConstraints::UB, d); 1789 // For a static memref and an affine map with no symbols, this is 1790 // always bounded. 1791 assert(ubConst.hasValue() && "should always have an upper bound"); 1792 if (ubConst.getValue() < 0) 1793 // This is due to an invalid map that maps to a negative space. 1794 return memrefType; 1795 // If dimension of new memrefType is dynamic, the value is -1. 1796 newShape[d] = ubConst.getValue() + 1; 1797 } 1798 } 1799 1800 // Create the new memref type after trivializing the old layout map. 1801 MemRefType newMemRefType = 1802 MemRefType::Builder(memrefType) 1803 .setShape(newShape) 1804 .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank))); 1805 1806 return newMemRefType; 1807 } 1808