1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 
27 using namespace mlir;
28 
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
struct LoopParams {
  Value lowerBound; // inclusive lower bound of the loop
  Value upperBound; // exclusive upper bound of the loop
  Value step;       // loop step (assumed strictly positive by users here)
};
} // namespace
38 
/// Clone `loop` into a new scf.for inserted just before it, appending
/// `newIterOperands` to its iter_args and `newYieldedValues` to its yielded
/// values. If `replaceLoopResults` is true, all uses of the original loop's
/// results are redirected to the matching leading results of the new loop.
/// The original loop is left in place; erasing it is the caller's
/// responsibility. Returns the newly created loop.
scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                                    ValueRange newIterOperands,
                                    ValueRange newYieldedValues,
                                    bool replaceLoopResults) {
  assert(newIterOperands.size() == newYieldedValues.size() &&
         "newIterOperands must be of the same size as newYieldedValues");

  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);

  auto &loopBody = *loop.getBody();
  auto &newLoopBody = *newLoop.getBody();
  // Clone / erase the yield inside the original loop to both:
  //   1. augment its operands with the newYieldedValues.
  //   2. automatically apply the BlockAndValueMapping on its operand
  // The augmented yield is inserted just before the original terminator so it
  // is picked up by the body-cloning loop below; the temporary is erased again
  // at the end of this function.
  auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
  b.setInsertionPoint(yield);
  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
  auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);

  // Clone the loop body with remaps.
  BlockAndValueMapping bvm;
  // a. remap the induction variable.
  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
  // b. remap the BB args.
  bvm.map(loopBody.getArguments(),
          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
  // c. remap the iter args.
  bvm.map(newIterOperands,
          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
  b.setInsertionPointToStart(&newLoopBody);
  // Skip the original yield terminator which does not have enough operands.
  // Note: `newYield` is NOT the terminator (the original `yield` still is), so
  // it does get cloned here and serves as the augmented yield of the new loop.
  for (auto &o : loopBody.without_terminator())
    b.clone(o, bvm);

  // Replace `loop`'s results if requested.
  if (replaceLoopResults) {
    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                    loop.getNumResults())))
      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }

  // Erase the temporary augmented yield from the original loop's body; its
  // clone lives on inside `newLoop`.
  // TODO: this is unsafe in the context of a PatternRewrite.
  newYield.erase();

  return newLoop;
}
93 
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations.
/// Returns failure if `region` has more than one block; on success, `region`
/// is left with a single block containing a call to the outlined function.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                 Location loc, Region &region,
                                                 StringRef funcName) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FuncOp>());

  // Values defined above `region` become extra trailing arguments of the
  // outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // The operand types of the region's terminator define the result types.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc = rewriter.create<FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<ReturnOp>(loc, originalTerminator->getResultTypes(),
                              originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    Operation *call = rewriter.create<CallOp>(loc, outlinedFunc, callValues);

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    BlockAndValueMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Captured index constants are rematerialized locally; uses inside the
      // outlined function then point at the clone rather than the argument.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        BlockAndValueMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
199 
200 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp, FuncOp *thenFn,
201                                 StringRef thenFnName, FuncOp *elseFn,
202                                 StringRef elseFnName) {
203   IRRewriter rewriter(b);
204   Location loc = ifOp.getLoc();
205   FailureOr<FuncOp> outlinedFuncOpOrFailure;
206   if (thenFn && !ifOp.getThenRegion().empty()) {
207     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
208         rewriter, loc, ifOp.getThenRegion(), thenFnName);
209     if (failed(outlinedFuncOpOrFailure))
210       return failure();
211     *thenFn = *outlinedFuncOpOrFailure;
212   }
213   if (elseFn && !ifOp.getElseRegion().empty()) {
214     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
215         rewriter, loc, ifOp.getElseRegion(), elseFnName);
216     if (failed(outlinedFuncOpOrFailure))
217       return failure();
218     *elseFn = *outlinedFuncOpOrFailure;
219   }
220   return success();
221 }
222 
223 bool mlir::getInnermostParallelLoops(Operation *rootOp,
224                                      SmallVectorImpl<scf::ParallelOp> &result) {
225   assert(rootOp != nullptr && "Root operation must not be a nullptr.");
226   bool rootEnclosesPloops = false;
227   for (Region &region : rootOp->getRegions()) {
228     for (Block &block : region.getBlocks()) {
229       for (Operation &op : block) {
230         bool enclosesPloops = getInnermostParallelLoops(&op, result);
231         rootEnclosesPloops |= enclosesPloops;
232         if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
233           rootEnclosesPloops = true;
234 
235           // Collect parallel loop if it is an innermost one.
236           if (!enclosesPloops)
237             result.push_back(ploop);
238         }
239       }
240     }
241   }
242   return rootEnclosesPloops;
243 }
244 
245 // Build the IR that performs ceil division of a positive value by a constant:
246 //    ceildiv(a, B) = divis(a + (B-1), B)
247 // where divis is rounding-to-zero division.
248 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
249                              int64_t divisor) {
250   assert(divisor > 0 && "expected positive divisor");
251   assert(dividend.getType().isIndex() && "expected index-typed value");
252 
253   Value divisorMinusOneCst =
254       builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
255   Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
256   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
257   return builder.create<arith::DivSIOp>(loc, sum, divisorCst);
258 }
259 
260 // Build the IR that performs ceil division of a positive value by another
261 // positive value:
262 //    ceildiv(a, b) = divis(a + (b - 1), b)
263 // where divis is rounding-to-zero division.
264 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
265                              Value divisor) {
266   assert(dividend.getType().isIndex() && "expected index-typed value");
267 
268   Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
269   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
270   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
271   return builder.create<arith::DivSIOp>(loc, sum, divisor);
272 }
273 
274 /// Helper to replace uses of loop carried values (iter_args) and loop
275 /// yield values while promoting single iteration scf.for ops.
276 static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
277   // Replace uses of iter arguments with iter operands (initial values).
278   auto iterOperands = forOp.getIterOperands();
279   auto iterArgs = forOp.getRegionIterArgs();
280   for (auto e : llvm::zip(iterOperands, iterArgs))
281     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
282 
283   // Replace uses of loop results with the values yielded by the loop.
284   auto outerResults = forOp.getResults();
285   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
286   for (auto e : llvm::zip(outerResults, innerResults))
287     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
288 }
289 
290 /// Promotes the loop body of a forOp to its containing block if the forOp
291 /// it can be determined that the loop has a single iteration.
292 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
293   auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
294   auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
295   auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
296   if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
297       ubCstOp.value() < 0 || stepCstOp.value() < 0)
298     return failure();
299   int64_t tripCount =
300       mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
301   if (tripCount != 1)
302     return failure();
303   auto iv = forOp.getInductionVar();
304   iv.replaceAllUsesWith(lbCstOp);
305 
306   replaceIterArgsAndYieldResults(forOp);
307 
308   // Move the loop body operations, except for its terminator, to the loop's
309   // containing block.
310   auto *parentBlock = forOp->getBlock();
311   forOp.getBody()->getTerminator()->erase();
312   parentBlock->getOperations().splice(Block::iterator(forOp),
313                                       forOp.getBody()->getOperations());
314   forOp.erase();
315   return success();
316 }
317 
318 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
319 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
320 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
321 /// unrolled iteration using annotateFn.
322 static void generateUnrolledLoop(
323     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
324     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
325     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
326     ValueRange iterArgs, ValueRange yieldedValues) {
327   // Builder to insert unrolled bodies just before the terminator of the body of
328   // 'forOp'.
329   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
330 
331   if (!annotateFn)
332     annotateFn = [](unsigned, Operation *, OpBuilder) {};
333 
334   // Keep a pointer to the last non-terminator operation in the original block
335   // so that we know what to clone (since we are doing this in-place).
336   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
337 
338   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
339   SmallVector<Value, 4> lastYielded(yieldedValues);
340 
341   for (unsigned i = 1; i < unrollFactor; i++) {
342     BlockAndValueMapping operandMap;
343 
344     // Prepare operand map.
345     operandMap.map(iterArgs, lastYielded);
346 
347     // If the induction variable is used, create a remapping to the value for
348     // this unrolled instance.
349     if (!forOpIV.use_empty()) {
350       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
351       operandMap.map(forOpIV, ivUnroll);
352     }
353 
354     // Clone the original body of 'forOp'.
355     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
356       Operation *clonedOp = builder.clone(*it, operandMap);
357       annotateFn(i, clonedOp, builder);
358     }
359 
360     // Update yielded values.
361     for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
362       lastYielded[i] = operandMap.lookup(yieldedValues[i]);
363   }
364 
365   // Make sure we annotate the Ops in the original body. We do this last so that
366   // any annotations are not copied into the cloned Ops above.
367   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
368     annotateFn(0, &*it, builder);
369 
370   // Update operands of the yield statement.
371   loopBodyBlock->getTerminator()->setOperands(lastYielded);
372 }
373 
/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
/// When the trip count is not (provably) a multiple of the factor, an epilogue
/// loop running the remaining iterations is emitted after the unrolled loop.
/// With constant bounds and an unroll factor of 1, only single-iteration
/// promotion is attempted. Constant lb/ub/step are asserted non-negative.
/// `annotateFn`, if non-null, is invoked on every op of every unrolled
/// instance (the original body's ops are annotated as instance 0).
LogicalResult mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  auto loc = forOp.getLoc();
  auto step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  // The dynamic-bounds path always keeps an epilogue loop; only the constant
  // path can prove it unnecessary.
  bool generateEpilogueLoop = true;

  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  if (lbCstOp && ubCstOp && stepCstOp) {
    // Constant loop bounds computation.
    int64_t lbCst = lbCstOp.value();
    int64_t ubCst = ubCstOp.value();
    int64_t stepCst = stepCstOp.value();
    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
           "expected positive loop bounds and step");
    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);

    if (unrollFactor == 1) {
      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
        return failure();
      return success();
    }

    // The unrolled loop covers the largest multiple of `unrollFactor`
    // iterations that fits in the trip count; the rest goes to the epilogue.
    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    assert(upperBoundUnrolledCst <= ubCst);
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
          loc, upperBoundUnrolledCst);
    else
      upperBoundUnrolled = ubCstOp;

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantIndexOp>(
                             loc, stepUnrolledCst);
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst =
        boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    // The epilogue is a clone of the original loop placed right after it;
    // only its lower bound changes (the unrolled loop's new upper bound).
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPoint(forOp->getBlock(),
                                      std::next(Block::iterator(forOp)));
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();
    auto epilogueIterOperands = epilogueForOp.getIterOperands();

    // External uses now read the epilogue's results; the epilogue's iter
    // operands are re-seeded with the unrolled loop's results (RAUW first,
    // then the epilogue use is re-introduced, so the order matters).
    for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
    }
    (void)promoteIfSingleIteration(epilogueForOp);
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step, b.create<arith::ConstantIndexOp>(loc, i));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  (void)promoteIfSingleIteration(forOp);
  return success();
}
494 
495 /// Return the new lower bound, upper bound, and step in that order. Insert any
496 /// additional bounds calculations before the given builder and any additional
497 /// conversion back to the original loop induction value inside the given Block.
498 static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
499                                 OpBuilder &insideLoopBuilder, Location loc,
500                                 Value lowerBound, Value upperBound, Value step,
501                                 Value inductionVar) {
502   // Check if the loop is already known to have a constant zero lower bound or
503   // a constant one step.
504   bool isZeroBased = false;
505   if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
506     isZeroBased = ubCst.value() == 0;
507 
508   bool isStepOne = false;
509   if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
510     isStepOne = stepCst.value() == 1;
511 
512   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
513   // assuming the step is strictly positive.  Update the bounds and the step
514   // of the loop to go from 0 to the number of iterations, if necessary.
515   // TODO: introduce support for negative steps or emit dynamic asserts
516   // on step positivity, whatever gets implemented first.
517   if (isZeroBased && isStepOne)
518     return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
519             /*step=*/step};
520 
521   Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
522   Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
523 
524   Value newLowerBound =
525       isZeroBased ? lowerBound
526                   : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
527   Value newStep =
528       isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);
529 
530   // Insert code computing the value of the original loop induction variable
531   // from the "normalized" one.
532   Value scaled =
533       isStepOne
534           ? inductionVar
535           : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
536   Value shifted =
537       isZeroBased
538           ? scaled
539           : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
540 
541   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
542                                        shifted.getDefiningOp()};
543   inductionVar.replaceAllUsesExcept(shifted, preserve);
544   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
545           /*step=*/newStep};
546 }
547 
548 /// Transform a loop with a strictly positive step
549 ///   for %i = %lb to %ub step %s
550 /// into a 0-based loop with step 1
551 ///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
552 ///     %i = %ii * %s + %lb
553 /// Insert the induction variable remapping in the body of `inner`, which is
554 /// expected to be either `loop` or another loop perfectly nested under `loop`.
555 /// Insert the definition of new bounds immediate before `outer`, which is
556 /// expected to be either `loop` or its parent in the loop nest.
557 static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
558   OpBuilder builder(outer);
559   OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
560   auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
561                                   loop.getLowerBound(), loop.getUpperBound(),
562                                   loop.getStep(), loop.getInductionVar());
563 
564   loop.setLowerBound(loopPieces.lowerBound);
565   loop.setUpperBound(loopPieces.upperBound);
566   loop.setStep(loopPieces.step);
567 }
568 
/// Coalesce the nest `loops` — expected to be perfectly nested, with
/// `loops.front()` outermost — into the single outermost loop: all loops are
/// normalized to [0, N) step 1, the outermost upper bound becomes the product
/// of the iteration counts, original induction variables are reconstructed
/// with div/mod arithmetic, and the inner loops are removed.
void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  // Nothing to coalesce for zero or one loop.
  if (loops.size() < 2)
    return;

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops)
    normalizeLoop(loop, outermost, innermost);

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder builder(outermost);
  Location loc = outermost.getLoc();
  Value upperBound = outermost.getUpperBound();
  for (auto loop : loops.drop_front())
    upperBound =
        builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
  outermost.setUpperBound(upperBound);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables. For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
  Value previous = outermost.getInductionVar();
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    // Process loops innermost (idx = size-1) to outermost (idx = 0).
    unsigned idx = loops.size() - i - 1;
    if (i != 0)
      previous = builder.create<arith::DivSIOp>(loc, previous,
                                                loops[idx + 1].getUpperBound());

    // The outermost loop (i == e - 1) needs no `mod`: the quotient is final.
    Value iv = (i == e - 1) ? previous
                            : builder.create<arith::RemSIOp>(
                                  loc, previous, loops[idx].getUpperBound());
    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
                               loops.back().getRegion());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  // Erasing `second` also removes every deeper loop nested inside it.
  scf::ForOp second = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(second.getOperation()),
      innermost.getBody()->getOperations());
  second.erase();
}
624 
/// Collapse the dimensions of the scf.parallel op `loops` according to
/// `combinedDimensions`: each inner vector lists the (0-based) dimensions to
/// be merged into one dimension of a new scf.parallel op that replaces
/// `loops`. Every dimension is first normalized to the [0, ub) step-1 form;
/// the original induction values are then reconstructed inside the new loop
/// with div/mod arithmetic. The original op is erased.
void mlir::collapseParallelLoops(
    scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder outsideBuilder(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    std::sort(dims.begin(), dims.end());

  // Normalize ParallelOp's iteration pattern.
  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
      normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
    auto resultBounds =
        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
                      loops.getLowerBound()[i], loops.getUpperBound()[i],
                      loops.getStep()[i], loops.getBody()->getArgument(i));

    normalizedLowerBounds.push_back(resultBounds.lowerBound);
    normalizedUpperBounds.push_back(resultBounds.upperBound);
    normalizedSteps.push_back(resultBounds.step);
  }

  // Combine iteration spaces. After normalization every collapsed dimension
  // has lower bound 0 and step 1 (cst0/cst1 are shared across dimensions);
  // its upper bound is the product of the grouped upper bounds.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
    Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimensions[i]) {
      newUpperBound = outsideBuilder.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          // NOTE(review): assumes every group in `combinedDimensions` is
          // non-empty — an empty group would underflow `j` below; confirm
          // callers guarantee this.
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop: move the (already rewritten) body
  // operations into the new op's block, then erase the original op.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
707 
708 // Hoist the ops within `outer` that appear before `inner`.
709 // Such ops include the ops that have been introduced by parametric tiling.
710 // Ops that come from triangular loops (i.e. that belong to the program slice
711 // rooted at `outer`) and ops that have side effects cannot be hoisted.
712 // Return failure when any op fails to hoist.
713 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
714   SetVector<Operation *> forwardSlice;
715   getForwardSlice(
716       outer.getInductionVar(), &forwardSlice,
717       [&inner](Operation *op) { return op != inner.getOperation(); });
718   LogicalResult status = success();
719   SmallVector<Operation *, 8> toHoist;
720   for (auto &op : outer.getBody()->without_terminator()) {
721     // Stop when encountering the inner loop.
722     if (&op == inner.getOperation())
723       break;
724     // Skip over non-hoistable ops.
725     if (forwardSlice.count(&op) > 0) {
726       status = failure();
727       continue;
728     }
729     // Skip intermediate scf::ForOp, these are not considered a failure.
730     if (isa<scf::ForOp>(op))
731       continue;
732     // Skip other ops with regions.
733     if (op.getNumRegions() > 0) {
734       status = failure();
735       continue;
736     }
737     // Skip if op has side effects.
738     // TODO: loads to immutable memory regions are ok.
739     if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
740       status = failure();
741       continue;
742     }
743     toHoist.push_back(&op);
744   }
745   auto *outerForOp = outer.getOperation();
746   for (auto *op : toHoist)
747     op->moveBefore(outerForOp);
748   return status;
749 }
750 
751 // Traverse the interTile and intraTile loops and try to hoist ops such that
752 // bands of perfectly nested loops are isolated.
753 // Return failure if either perfect interTile or perfect intraTile bands cannot
754 // be formed.
755 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
756   LogicalResult status = success();
757   const Loops &interTile = tileLoops.first;
758   const Loops &intraTile = tileLoops.second;
759   auto size = interTile.size();
760   assert(size == intraTile.size());
761   if (size <= 1)
762     return success();
763   for (unsigned s = 1; s < size; ++s)
764     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
765                                : failure();
766   for (unsigned s = 1; s < size; ++s)
767     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
768                                : failure();
769   return status;
770 }
771 
772 /// Collect perfectly nested loops starting from `rootForOps`.  Loops are
773 /// perfectly nested if each loop is the first and only non-terminator operation
774 /// in the parent loop.  Collect at most `maxLoops` loops and append them to
775 /// `forOps`.
776 template <typename T>
777 static void getPerfectlyNestedLoopsImpl(
778     SmallVectorImpl<T> &forOps, T rootForOp,
779     unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
780   for (unsigned i = 0; i < maxLoops; ++i) {
781     forOps.push_back(rootForOp);
782     Block &body = rootForOp.getRegion().front();
783     if (body.begin() != std::prev(body.end(), 2))
784       return;
785 
786     rootForOp = dyn_cast<T>(&body.front());
787     if (!rootForOp)
788       return;
789   }
790 }
791 
// Strip-mine `forOp` by `factor`: widen its step to `step * factor` and, in
// each loop in `targets`, create an inner loop covering one strip
// ([iv, min(iv + newStep, ub)) with the original step).  The existing body of
// each target is moved into its new inner loop.  Returns the newly created
// inner loops, one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  // The outer loop now advances by `originalStep * factor` per iteration.
  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done.  `begin`/`nOps`
    // are captured before new ops are inserted below so that only the
    // pre-existing body ops get spliced into the inner loop.
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Upper bound of the strip: min(iv + newStep, outer upper bound), computed
    // as a signed compare + select so the last partial strip is handled.
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
                                         forOp.getUpperBound(), stepped);
    Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
                                         forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    // The last op (the terminator) stays in `t`; uses of the outer induction
    // variable inside the moved body are redirected to the inner one.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
827 
828 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
829 // Returns the new for operation, nested immediately under `target`.
830 template <typename SizeType>
831 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
832                                 scf::ForOp target) {
833   // TODO: Use cheap structural assertions that targets are nested under
834   // forOp and that targets are not nested under each other when DominanceInfo
835   // exposes the capability. It seems overkill to construct a whole function
836   // dominance tree at this point.
837   auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
838   assert(res.size() == 1 && "Expected 1 inner forOp");
839   return res[0];
840 }
841 
842 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
843                                  ArrayRef<Value> sizes,
844                                  ArrayRef<scf::ForOp> targets) {
845   SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
846   SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
847   for (auto it : llvm::zip(forOps, sizes)) {
848     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
849     res.push_back(step);
850     currentTargets = step;
851   }
852   return res;
853 }
854 
855 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
856                  scf::ForOp target) {
857   SmallVector<scf::ForOp, 8> res;
858   for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
859     assert(loops.size() == 1);
860     res.push_back(loops[0]);
861   }
862   return res;
863 }
864 
865 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
866   // Collect perfectly nested loops.  If more size values provided than nested
867   // loops available, truncate `sizes`.
868   SmallVector<scf::ForOp, 4> forOps;
869   forOps.reserve(sizes.size());
870   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
871   if (forOps.size() < sizes.size())
872     sizes = sizes.take_front(forOps.size());
873 
874   return ::tile(forOps, sizes, forOps.back());
875 }
876 
// Public entry point: collect into `nestedLoops` the maximal perfectly nested
// band of scf::ForOp starting at `root` (no loop-count limit).
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
881 
882 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
883                                        ArrayRef<int64_t> sizes) {
884   // Collect perfectly nested loops.  If more size values provided than nested
885   // loops available, truncate `sizes`.
886   SmallVector<scf::ForOp, 4> forOps;
887   forOps.reserve(sizes.size());
888   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
889   if (forOps.size() < sizes.size())
890     sizes = sizes.take_front(forOps.size());
891 
892   // Compute the tile sizes such that i-th outer loop executes size[i]
893   // iterations.  Given that the loop current executes
894   //   numIterations = ceildiv((upperBound - lowerBound), step)
895   // iterations, we need to tile with size ceildiv(numIterations, size[i]).
896   SmallVector<Value, 4> tileSizes;
897   tileSizes.reserve(sizes.size());
898   for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
899     assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
900 
901     auto forOp = forOps[i];
902     OpBuilder builder(forOp);
903     auto loc = forOp.getLoc();
904     Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
905                                                forOp.getLowerBound());
906     Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
907     Value iterationsPerBlock =
908         ceilDivPositive(builder, loc, numIterations, sizes[i]);
909     tileSizes.push_back(iterationsPerBlock);
910   }
911 
912   // Call parametric tiling with the given sizes.
913   auto intraTile = tile(forOps, tileSizes, forOps.back());
914   TileLoops tileLoops = std::make_pair(forOps, intraTile);
915 
916   // TODO: for now we just ignore the result of band isolation.
917   // In the future, mapping decisions may be impacted by the ability to
918   // isolate perfectly nested bands.
919   (void)tryIsolateBands(tileLoops);
920 
921   return tileLoops;
922 }
923