//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;

namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
struct LoopParams {
  Value lowerBound;
  Value upperBound;
  Value step;
};
} // namespace

/// Create a new scf.for with the iter_args of `loop` plus `newIterOperands`,
/// move the original body into it, and append the values produced by
/// `newYieldValuesFn` to the terminator. The original loop is turned into a
/// dead no-op loop (its results are replaced by the new loop's leading
/// results); the caller is expected to erase it.
scf::ForOp
mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                               ValueRange newIterOperands,
                               const NewYieldValueFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(loop);
  auto operands = llvm::to_vector(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  // Empty body builder: the body is spliced over from the old loop below.
  scf::ForOp newLoop = builder.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands, [](OpBuilder &, Location, Value, ValueRange) {});

  Block *loopBody = loop.getBody();
  Block *newLoopBody = newLoop.getBody();

  // Move the body of the original loop to the new loop.
  newLoopBody->getOperations().splice(newLoopBody->end(),
                                      loopBody->getOperations());

  // Generate the new yield values to use by using the callback and append the
  // yield values to the scf.yield operation.
  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
  ArrayRef<BlockArgument> newBBArgs =
      newLoopBody->getArguments().take_back(newIterOperands.size());
  {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPoint(yield);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
    assert(newIterOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    yield.getResultsMutable().append(newYieldedValues);
  }

  // Remap the BlockArguments from the original loop to the new loop
  // BlockArguments.
  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
  for (auto it :
       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));

  // Replace all uses of `newIterOperands` with the corresponding basic block
  // arguments, but only for uses nested inside the new loop (uses outside
  // must keep referring to the original SSA values).
  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
      Operation *user = use.getOwner();
      return newLoop->isProperAncestor(user);
    });
  }

  // Replace all uses of the original loop with corresponding values from the
  // new loop.
  loop.replaceAllUsesWith(
      newLoop.getResults().take_front(loop.getNumResults()));

  // Add a fake yield to the original loop body that just returns the
  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
  // The loop is dead. The caller is expected to erase it.
  builder.setInsertionPointToEnd(loopBody);
  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());

  return newLoop;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types are the types of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());

  // Values defined above the region become extra trailing arguments of the
  // outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(region's arguments, captured arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    BlockAndValueMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`
  // so constants are local to the function and canonicalize easily.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        BlockAndValueMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}

/// Outline the then and (optionally) else regions of `ifOp` into calls to
/// freshly created functions named `thenFnName`/`elseFnName`. The outlined
/// functions are returned through `thenFn`/`elseFn` when those pointers are
/// non-null; a region is only outlined when its pointer is provided and the
/// region is non-empty.
LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
                                func::FuncOp *thenFn, StringRef thenFnName,
                                func::FuncOp *elseFn, StringRef elseFnName) {
  IRRewriter rewriter(b);
  Location loc = ifOp.getLoc();
  FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
  if (thenFn && !ifOp.getThenRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getThenRegion(), thenFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *thenFn = *outlinedFuncOpOrFailure;
  }
  if (elseFn && !ifOp.getElseRegion().empty()) {
    outlinedFuncOpOrFailure = outlineSingleBlockRegion(
        rewriter, loc, ifOp.getElseRegion(), elseFnName);
    if (failed(outlinedFuncOpOrFailure))
      return failure();
    *elseFn = *outlinedFuncOpOrFailure;
  }
  return success();
}

/// Recursively collect into `result` the innermost scf.parallel loops nested
/// under `rootOp`, i.e. parallel loops that do not themselves enclose another
/// parallel loop. Returns true if `rootOp` encloses any parallel loop.
bool mlir::getInnermostParallelLoops(Operation *rootOp,
                                     SmallVectorImpl<scf::ParallelOp> &result) {
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  bool rootEnclosesPloops = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &op : block) {
        // First recurse: this tells us whether `op` encloses parallel loops.
        bool enclosesPloops = getInnermostParallelLoops(&op, result);
        rootEnclosesPloops |= enclosesPloops;
        if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
          rootEnclosesPloops = true;

          // Collect parallel loop if it is an innermost one.
          if (!enclosesPloops)
            result.push_back(ploop);
        }
      }
    }
  }
  return rootEnclosesPloops;
}

// Build the IR that performs ceil division of a positive value by a constant:
//    ceildiv(a, B) = divis(a + (B-1), B)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIndex() && "expected index-typed value");

  Value divisorMinusOneCst =
      builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
  Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
  return builder.create<arith::DivSIOp>(loc, sum, divisorCst);
}

// Build the IR that performs ceil division of a positive value by another
// positive value:
//    ceildiv(a, b) = divis(a + (b - 1), b)
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIndex() && "expected index-typed value");

  Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
  Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
  return builder.create<arith::DivSIOp>(loc, sum, divisor);
}

/// Helper to replace uses of loop carried values (iter_args) and loop
/// yield values while promoting single iteration scf.for ops.
static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
  // Replace uses of iter arguments with iter operands (initial values).
  auto iterOperands = forOp.getIterOperands();
  auto iterArgs = forOp.getRegionIterArgs();
  for (auto e : llvm::zip(iterOperands, iterArgs))
    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));

  // Replace uses of loop results with the values yielded by the loop.
  auto outerResults = forOp.getResults();
  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
  for (auto e : llvm::zip(outerResults, innerResults))
    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
}

/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration. Requires constant,
/// non-negative lower bound, upper bound, and step.
LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
      ubCstOp.value() < 0 || stepCstOp.value() < 0)
    return failure();
  int64_t tripCount =
      mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
  if (tripCount != 1)
    return failure();
  // The single iteration executes with iv == lower bound.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(lbCstOp);

  replaceIterArgsAndYieldResults(forOp);

  // Move the loop body operations, except for its terminator, to the loop's
  // containing block.
  auto *parentBlock = forOp->getBlock();
  forOp.getBody()->getTerminator()->erase();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}

/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body.
/// If specified, annotates the Ops in each unrolled iteration using
/// annotateFn.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body of
  // 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  if (!annotateFn)
    annotateFn = [](unsigned, Operation *, OpBuilder) {};

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    BlockAndValueMapping operandMap;

    // Prepare operand map: iter_args flow from the previous copy's yields.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookup(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so that
  // any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}

/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
LogicalResult mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  auto loc = forOp.getLoc();
  auto step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  if (lbCstOp && ubCstOp && stepCstOp) {
    // Constant loop bounds computation.
    int64_t lbCst = lbCstOp.value();
    int64_t ubCst = ubCstOp.value();
    int64_t stepCst = stepCstOp.value();
    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
           "expected positive loop bounds and step");
    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);

    if (unrollFactor == 1) {
      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
        return failure();
      return success();
    }

    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    assert(upperBoundUnrolledCst <= ubCst);
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    // An epilogue is only needed when the trip count is not an even multiple
    // of the unroll factor.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
          loc, upperBoundUnrolledCst);
    else
      upperBoundUnrolled = ubCstOp;

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantIndexOp>(
                             loc, stepUnrolledCst);
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst =
        boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPoint(forOp->getBlock(),
                                      std::next(Block::iterator(forOp)));
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results: the epilogue now produces the final values.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    // The epilogue's iter_args are initialized from the unrolled loop results.
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getNumIterOperands(), results);
    (void)promoteIfSingleIteration(epilogueForOp);
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step, b.create<arith::ConstantIndexOp>(loc, i));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  (void)promoteIfSingleIteration(forOp);
  return success();
}

/// Return the new lower bound, upper bound, and step in that order. Insert any
/// additional bounds calculations before the given builder and any additional
/// conversion back to the original loop induction value inside the given Block.
static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
                                OpBuilder &insideLoopBuilder, Location loc,
                                Value lowerBound, Value upperBound, Value step,
                                Value inductionVar) {
  // Check if the loop is already known to have a constant zero lower bound or
  // a constant one step.
  bool isZeroBased = false;
  // NOTE: despite the name, `ubCst` inspects the *lower* bound's defining
  // constant.
  if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
    isZeroBased = ubCst.value() == 0;

  bool isStepOne = false;
  if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
    isStepOne = stepCst.value() == 1;

  // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
  // assuming the step is strictly positive. Update the bounds and the step
  // of the loop to go from 0 to the number of iterations, if necessary.
  // TODO: introduce support for negative steps or emit dynamic asserts
  // on step positivity, whatever gets implemented first.
  if (isZeroBased && isStepOne)
    return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
            /*step=*/step};

  Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
  Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);

  Value newLowerBound =
      isZeroBased ? lowerBound
                  : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
  Value newStep =
      isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);

  // Insert code computing the value of the original loop induction variable
  // from the "normalized" one.
  Value scaled =
      isStepOne
          ? inductionVar
          : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
  Value shifted =
      isZeroBased
          ? scaled
          : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);

  // Keep the remapping ops themselves using the original induction variable;
  // everything else is rewritten to use the shifted value.
  SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
                                       shifted.getDefiningOp()};
  inductionVar.replaceAllUsesExcept(shifted, preserve);
  return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
          /*step=*/newStep};
}

/// Transform a loop with a strictly positive step
///   for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
///     %i = %ii * %s + %lb
/// Insert the induction variable remapping in the body of `inner`, which is
/// expected to be either `loop` or another loop perfectly nested under `loop`.
/// Insert the definition of new bounds immediately before `outer`, which is
/// expected to be either `loop` or its parent in the loop nest.
static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
  OpBuilder builder(outer);
  OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
  auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
                                  loop.getLowerBound(), loop.getUpperBound(),
                                  loop.getStep(), loop.getInductionVar());

  loop.setLowerBound(loopPieces.lowerBound);
  loop.setUpperBound(loopPieces.upperBound);
  loop.setStep(loopPieces.step);
}

/// Coalesce a perfect nest of scf.for `loops` into the single outermost loop.
/// After normalization, the outermost loop iterates over the product of all
/// trip counts and the original induction variables are recovered with
/// div/rem arithmetic. No-op when fewer than two loops are given.
void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return;

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops)
    normalizeLoop(loop, outermost, innermost);

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder builder(outermost);
  Location loc = outermost.getLoc();
  Value upperBound = outermost.getUpperBound();
  for (auto loop : loops.drop_front())
    upperBound =
        builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
  outermost.setUpperBound(upperBound);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables. For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
  Value previous = outermost.getInductionVar();
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    unsigned idx = loops.size() - i - 1;
    if (i != 0)
      previous = builder.create<arith::DivSIOp>(loc, previous,
                                                loops[idx + 1].getUpperBound());

    // The outermost loop (last iteration here) needs no `rem`.
    Value iv = (i == e - 1) ? previous
                            : builder.create<arith::RemSIOp>(
                                  loc, previous, loops[idx].getUpperBound());
    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
                               loops.back().getRegion());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  scf::ForOp second = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(second.getOperation()),
      innermost.getBody()->getOperations());
  second.erase();
}

/// Collapse the dimensions of the scf.parallel op `loops` according to
/// `combinedDimensions`: each entry lists the original dimensions that are
/// merged into one dimension of a new scf.parallel op that replaces `loops`.
void mlir::collapseParallelLoops(
    scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder outsideBuilder(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    std::sort(dims.begin(), dims.end());

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
      normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
    auto resultBounds =
        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
                      loops.getLowerBound()[i], loops.getUpperBound()[i],
                      loops.getStep()[i], loops.getBody()->getArgument(i));

    normalizedLowerBounds.push_back(resultBounds.lowerBound);
    normalizedUpperBounds.push_back(resultBounds.upperBound);
    normalizedSteps.push_back(resultBounds.step);
  }

  // Combine iteration spaces: each combined upper bound is the product of the
  // normalized upper bounds of its member dimensions.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
    Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimensions[i]) {
      newUpperBound = outsideBuilder.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}

// Hoist the ops within `outer` that appear before `inner`.
// Such ops include the ops that have been introduced by parametric tiling.
// Ops that come from triangular loops (i.e. that belong to the program slice
// rooted at `outer`) and ops that have side effects cannot be hoisted.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
  SetVector<Operation *> forwardSlice;
  // Slice of ops transitively depending on `outer`'s induction variable,
  // stopping at `inner`: these cannot be hoisted above `outer`.
  getForwardSlice(
      outer.getInductionVar(), &forwardSlice,
      [&inner](Operation *op) { return op != inner.getOperation(); });
  LogicalResult status = success();
  SmallVector<Operation *, 8> toHoist;
  for (auto &op : outer.getBody()->without_terminator()) {
    // Stop when encountering the inner loop.
    if (&op == inner.getOperation())
      break;
    // Skip over non-hoistable ops.
    if (forwardSlice.count(&op) > 0) {
      status = failure();
      continue;
    }
    // Skip intermediate scf::ForOp, these are not considered a failure.
    if (isa<scf::ForOp>(op))
      continue;
    // Skip other ops with regions.
    if (op.getNumRegions() > 0) {
      status = failure();
      continue;
    }
    // Skip if op has side effects.
    // TODO: loads to immutable memory regions are ok.
    if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
      status = failure();
      continue;
    }
    toHoist.push_back(&op);
  }
  auto *outerForOp = outer.getOperation();
  for (auto *op : toHoist)
    op->moveBefore(outerForOp);
  return status;
}

// Traverse the interTile and intraTile loops and try to hoist ops such that
// bands of perfectly nested loops are isolated.
// Return failure if either perfect interTile or perfect intraTile bands cannot
// be formed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  LogicalResult status = success();
  const Loops &interTile = tileLoops.first;
  const Loops &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();
  // Once a hoist fails, keep the failure status but stop hoisting.
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
                               : failure();
  for (unsigned s = 1; s < size; ++s)
    status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
                               : failure();
  return status;
}

/// Collect perfectly nested loops starting from `rootForOp`. Loops are
/// perfectly nested if each loop is the first and only non-terminator operation
/// in the parent loop. Collect at most `maxLoops` loops and append them to
/// `forOps`.
793 template <typename T> 794 static void getPerfectlyNestedLoopsImpl( 795 SmallVectorImpl<T> &forOps, T rootForOp, 796 unsigned maxLoops = std::numeric_limits<unsigned>::max()) { 797 for (unsigned i = 0; i < maxLoops; ++i) { 798 forOps.push_back(rootForOp); 799 Block &body = rootForOp.getRegion().front(); 800 if (body.begin() != std::prev(body.end(), 2)) 801 return; 802 803 rootForOp = dyn_cast<T>(&body.front()); 804 if (!rootForOp) 805 return; 806 } 807 } 808 809 static Loops stripmineSink(scf::ForOp forOp, Value factor, 810 ArrayRef<scf::ForOp> targets) { 811 auto originalStep = forOp.getStep(); 812 auto iv = forOp.getInductionVar(); 813 814 OpBuilder b(forOp); 815 forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor)); 816 817 Loops innerLoops; 818 for (auto t : targets) { 819 // Save information for splicing ops out of t when done 820 auto begin = t.getBody()->begin(); 821 auto nOps = t.getBody()->getOperations().size(); 822 823 // Insert newForOp before the terminator of `t`. 824 auto b = OpBuilder::atBlockTerminator((t.getBody())); 825 Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep()); 826 Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt, 827 forOp.getUpperBound(), stepped); 828 Value ub = b.create<arith::SelectOp>(t.getLoc(), less, 829 forOp.getUpperBound(), stepped); 830 831 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. 832 auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep); 833 newForOp.getBody()->getOperations().splice( 834 newForOp.getBody()->getOperations().begin(), 835 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1)); 836 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(), 837 newForOp.getRegion()); 838 839 innerLoops.push_back(newForOp); 840 } 841 842 return innerLoops; 843 } 844 845 // Stripmines a `forOp` by `factor` and sinks it under a single `target`. 
846 // Returns the new for operation, nested immediately under `target`. 847 template <typename SizeType> 848 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor, 849 scf::ForOp target) { 850 // TODO: Use cheap structural assertions that targets are nested under 851 // forOp and that targets are not nested under each other when DominanceInfo 852 // exposes the capability. It seems overkill to construct a whole function 853 // dominance tree at this point. 854 auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target)); 855 assert(res.size() == 1 && "Expected 1 inner forOp"); 856 return res[0]; 857 } 858 859 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps, 860 ArrayRef<Value> sizes, 861 ArrayRef<scf::ForOp> targets) { 862 SmallVector<SmallVector<scf::ForOp, 8>, 8> res; 863 SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end()); 864 for (auto it : llvm::zip(forOps, sizes)) { 865 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets); 866 res.push_back(step); 867 currentTargets = step; 868 } 869 return res; 870 } 871 872 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes, 873 scf::ForOp target) { 874 SmallVector<scf::ForOp, 8> res; 875 for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) { 876 assert(loops.size() == 1); 877 res.push_back(loops[0]); 878 } 879 return res; 880 } 881 882 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) { 883 // Collect perfectly nested loops. If more size values provided than nested 884 // loops available, truncate `sizes`. 
885 SmallVector<scf::ForOp, 4> forOps; 886 forOps.reserve(sizes.size()); 887 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); 888 if (forOps.size() < sizes.size()) 889 sizes = sizes.take_front(forOps.size()); 890 891 return ::tile(forOps, sizes, forOps.back()); 892 } 893 894 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops, 895 scf::ForOp root) { 896 getPerfectlyNestedLoopsImpl(nestedLoops, root); 897 } 898 899 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, 900 ArrayRef<int64_t> sizes) { 901 // Collect perfectly nested loops. If more size values provided than nested 902 // loops available, truncate `sizes`. 903 SmallVector<scf::ForOp, 4> forOps; 904 forOps.reserve(sizes.size()); 905 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size()); 906 if (forOps.size() < sizes.size()) 907 sizes = sizes.take_front(forOps.size()); 908 909 // Compute the tile sizes such that i-th outer loop executes size[i] 910 // iterations. Given that the loop current executes 911 // numIterations = ceildiv((upperBound - lowerBound), step) 912 // iterations, we need to tile with size ceildiv(numIterations, size[i]). 913 SmallVector<Value, 4> tileSizes; 914 tileSizes.reserve(sizes.size()); 915 for (unsigned i = 0, e = sizes.size(); i < e; ++i) { 916 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining"); 917 918 auto forOp = forOps[i]; 919 OpBuilder builder(forOp); 920 auto loc = forOp.getLoc(); 921 Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(), 922 forOp.getLowerBound()); 923 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep()); 924 Value iterationsPerBlock = 925 ceilDivPositive(builder, loc, numIterations, sizes[i]); 926 tileSizes.push_back(iterationsPerBlock); 927 } 928 929 // Call parametric tiling with the given sizes. 
930 auto intraTile = tile(forOps, tileSizes, forOps.back()); 931 TileLoops tileLoops = std::make_pair(forOps, intraTile); 932 933 // TODO: for now we just ignore the result of band isolation. 934 // In the future, mapping decisions may be impacted by the ability to 935 // isolate perfectly nested bands. 936 (void)tryIsolateBands(tileLoops); 937 938 return tileLoops; 939 } 940