1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 
27 using namespace mlir;
28 
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order. Produced by `normalizeLoop` below, which rewrites a
// loop into a 0-based unit-step form.
struct LoopParams {
  Value lowerBound; // Loop lower bound.
  Value upperBound; // Loop upper bound.
  Value step;       // Loop step.
};
} // namespace
38 
/// Clone `loop` into a new scf.for, inserted just before `loop`, that carries
/// `newIterOperands` as additional iter_args and yields `newYieldedValues` in
/// addition to the original yields. The original loop is left in place
/// (callers are expected to erase it). If `replaceLoopResults` is true, all
/// uses of the original loop's results are redirected to the corresponding
/// leading results of the new loop.
scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                                    ValueRange newIterOperands,
                                    ValueRange newYieldedValues,
                                    bool replaceLoopResults) {
  assert(newIterOperands.size() == newYieldedValues.size() &&
         "newIterOperands must be of the same size as newYieldedValues");

  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  // New iter operands = original iter operands followed by the extra ones.
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);

  auto &loopBody = *loop.getBody();
  auto &newLoopBody = *newLoop.getBody();
  // Clone / erase the yield inside the original loop to both:
  //   1. augment its operands with the newYieldedValues.
  //   2. automatically apply the BlockAndValueMapping on its operand
  auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
  b.setInsertionPoint(yield);
  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
  // `newYield` is inserted *before* the original terminator, so it is not the
  // terminator of the old body; it gets cloned by the body-copy loop below
  // (producing the augmented yield in the new loop) and is erased at the end
  // of this function.
  auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);

  // Clone the loop body with remaps.
  BlockAndValueMapping bvm;
  // a. remap the induction variable.
  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
  // b. remap the BB args.
  bvm.map(loopBody.getArguments(),
          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
  // c. remap the iter args.
  bvm.map(newIterOperands,
          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
  b.setInsertionPointToStart(&newLoopBody);
  // Skip the original yield terminator which does not have enough operands.
  for (auto &o : loopBody.without_terminator())
    b.clone(o, bvm);

  // Replace `loop`'s results if requested.
  if (replaceLoopResults) {
    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                    loop.getNumResults())))
      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }

  // TODO: this is unsafe in the context of a PatternRewrite.
  newYield.erase();

  return newLoop;
}
93 
94 /// Outline a region with a single block into a new FuncOp.
95 /// Assumes the FuncOp result types is the type of the yielded operands of the
96 /// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations.
99 // TODO: support more than single-block regions.
100 // TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());

  // Values defined above the region become trailing arguments of the outlined
  // function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // The outlined function returns exactly what the region's terminator yields.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // Call operands: reconstructed block arguments followed by the captures.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    Operation *call =
        rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    BlockAndValueMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Index constants are re-materialized inside the outlined function
      // (instead of using the argument) to allow simple canonicalizations.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        BlockAndValueMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
202 
203 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
204                                 func::FuncOp *thenFn, StringRef thenFnName,
205                                 func::FuncOp *elseFn, StringRef elseFnName) {
206   IRRewriter rewriter(b);
207   Location loc = ifOp.getLoc();
208   FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
209   if (thenFn && !ifOp.getThenRegion().empty()) {
210     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
211         rewriter, loc, ifOp.getThenRegion(), thenFnName);
212     if (failed(outlinedFuncOpOrFailure))
213       return failure();
214     *thenFn = *outlinedFuncOpOrFailure;
215   }
216   if (elseFn && !ifOp.getElseRegion().empty()) {
217     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
218         rewriter, loc, ifOp.getElseRegion(), elseFnName);
219     if (failed(outlinedFuncOpOrFailure))
220       return failure();
221     *elseFn = *outlinedFuncOpOrFailure;
222   }
223   return success();
224 }
225 
226 bool mlir::getInnermostParallelLoops(Operation *rootOp,
227                                      SmallVectorImpl<scf::ParallelOp> &result) {
228   assert(rootOp != nullptr && "Root operation must not be a nullptr.");
229   bool rootEnclosesPloops = false;
230   for (Region &region : rootOp->getRegions()) {
231     for (Block &block : region.getBlocks()) {
232       for (Operation &op : block) {
233         bool enclosesPloops = getInnermostParallelLoops(&op, result);
234         rootEnclosesPloops |= enclosesPloops;
235         if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
236           rootEnclosesPloops = true;
237 
238           // Collect parallel loop if it is an innermost one.
239           if (!enclosesPloops)
240             result.push_back(ploop);
241         }
242       }
243     }
244   }
245   return rootEnclosesPloops;
246 }
247 
248 // Build the IR that performs ceil division of a positive value by a constant:
249 //    ceildiv(a, B) = divis(a + (B-1), B)
250 // where divis is rounding-to-zero division.
251 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
252                              int64_t divisor) {
253   assert(divisor > 0 && "expected positive divisor");
254   assert(dividend.getType().isIndex() && "expected index-typed value");
255 
256   Value divisorMinusOneCst =
257       builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
258   Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
259   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
260   return builder.create<arith::DivSIOp>(loc, sum, divisorCst);
261 }
262 
263 // Build the IR that performs ceil division of a positive value by another
264 // positive value:
265 //    ceildiv(a, b) = divis(a + (b - 1), b)
266 // where divis is rounding-to-zero division.
267 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
268                              Value divisor) {
269   assert(dividend.getType().isIndex() && "expected index-typed value");
270 
271   Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
272   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
273   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
274   return builder.create<arith::DivSIOp>(loc, sum, divisor);
275 }
276 
277 /// Helper to replace uses of loop carried values (iter_args) and loop
278 /// yield values while promoting single iteration scf.for ops.
279 static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
280   // Replace uses of iter arguments with iter operands (initial values).
281   auto iterOperands = forOp.getIterOperands();
282   auto iterArgs = forOp.getRegionIterArgs();
283   for (auto e : llvm::zip(iterOperands, iterArgs))
284     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
285 
286   // Replace uses of loop results with the values yielded by the loop.
287   auto outerResults = forOp.getResults();
288   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
289   for (auto e : llvm::zip(outerResults, innerResults))
290     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
291 }
292 
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
295 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
296   auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
297   auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
298   auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
299   if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
300       ubCstOp.value() < 0 || stepCstOp.value() < 0)
301     return failure();
302   int64_t tripCount =
303       mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
304   if (tripCount != 1)
305     return failure();
306   auto iv = forOp.getInductionVar();
307   iv.replaceAllUsesWith(lbCstOp);
308 
309   replaceIterArgsAndYieldResults(forOp);
310 
311   // Move the loop body operations, except for its terminator, to the loop's
312   // containing block.
313   auto *parentBlock = forOp->getBlock();
314   forOp.getBody()->getTerminator()->erase();
315   parentBlock->getOperations().splice(Block::iterator(forOp),
316                                       forOp.getBody()->getOperations());
317   forOp.erase();
318   return success();
319 }
320 
321 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
322 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
323 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
324 /// unrolled iteration using annotateFn.
325 static void generateUnrolledLoop(
326     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
327     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
328     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
329     ValueRange iterArgs, ValueRange yieldedValues) {
330   // Builder to insert unrolled bodies just before the terminator of the body of
331   // 'forOp'.
332   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
333 
334   if (!annotateFn)
335     annotateFn = [](unsigned, Operation *, OpBuilder) {};
336 
337   // Keep a pointer to the last non-terminator operation in the original block
338   // so that we know what to clone (since we are doing this in-place).
339   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
340 
341   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
342   SmallVector<Value, 4> lastYielded(yieldedValues);
343 
344   for (unsigned i = 1; i < unrollFactor; i++) {
345     BlockAndValueMapping operandMap;
346 
347     // Prepare operand map.
348     operandMap.map(iterArgs, lastYielded);
349 
350     // If the induction variable is used, create a remapping to the value for
351     // this unrolled instance.
352     if (!forOpIV.use_empty()) {
353       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
354       operandMap.map(forOpIV, ivUnroll);
355     }
356 
357     // Clone the original body of 'forOp'.
358     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
359       Operation *clonedOp = builder.clone(*it, operandMap);
360       annotateFn(i, clonedOp, builder);
361     }
362 
363     // Update yielded values.
364     for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
365       lastYielded[i] = operandMap.lookup(yieldedValues[i]);
366   }
367 
368   // Make sure we annotate the Ops in the original body. We do this last so that
369   // any annotations are not copied into the cloned Ops above.
370   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
371     annotateFn(0, &*it, builder);
372 
373   // Update operands of the yield statement.
374   loopBodyBlock->getTerminator()->setOperands(lastYielded);
375 }
376 
377 /// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
LogicalResult mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  auto loc = forOp.getLoc();
  auto step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  if (lbCstOp && ubCstOp && stepCstOp) {
    // Constant loop bounds computation.
    int64_t lbCst = lbCstOp.value();
    int64_t ubCst = ubCstOp.value();
    int64_t stepCst = stepCstOp.value();
    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
           "expected positive loop bounds and step");
    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);

    // An unroll factor of 1 is a no-op, but a single-iteration loop can still
    // be promoted into its parent block.
    if (unrollFactor == 1) {
      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
        return failure();
      return success();
    }

    // The unrolled loop covers the largest multiple of `unrollFactor`
    // iterations; the remainder (if any) runs in the epilogue loop.
    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    assert(upperBoundUnrolledCst <= ubCst);
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
          loc, upperBoundUnrolledCst);
    else
      upperBoundUnrolled = ubCstOp;

    // Create constant for 'stepUnrolled'. Reuse the original step when it is
    // unchanged (i.e. unrollFactor degenerate cases).
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantIndexOp>(
                             loc, stepUnrolledCst);
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst =
        boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    // The epilogue is a clone of the original loop placed right after it,
    // iterating from `upperBoundUnrolled` to the original upper bound.
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPoint(forOp->getBlock(),
                                      std::next(Block::iterator(forOp)));
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();
    auto epilogueIterOperands = epilogueForOp.getIterOperands();

    // Chain the loops: external uses of the original results move to the
    // epilogue's results, and the epilogue's initial iter operands become the
    // unrolled loop's results.
    for (auto e : llvm::zip(results, epilogueResults, epilogueIterOperands)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      epilogueForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
    }
    (void)promoteIfSingleIteration(epilogueForOp);
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step, b.create<arith::ConstantIndexOp>(loc, i));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  (void)promoteIfSingleIteration(forOp);
  return success();
}
497 
498 /// Return the new lower bound, upper bound, and step in that order. Insert any
499 /// additional bounds calculations before the given builder and any additional
500 /// conversion back to the original loop induction value inside the given Block.
501 static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
502                                 OpBuilder &insideLoopBuilder, Location loc,
503                                 Value lowerBound, Value upperBound, Value step,
504                                 Value inductionVar) {
505   // Check if the loop is already known to have a constant zero lower bound or
506   // a constant one step.
507   bool isZeroBased = false;
508   if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
509     isZeroBased = ubCst.value() == 0;
510 
511   bool isStepOne = false;
512   if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
513     isStepOne = stepCst.value() == 1;
514 
515   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
516   // assuming the step is strictly positive.  Update the bounds and the step
517   // of the loop to go from 0 to the number of iterations, if necessary.
518   // TODO: introduce support for negative steps or emit dynamic asserts
519   // on step positivity, whatever gets implemented first.
520   if (isZeroBased && isStepOne)
521     return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
522             /*step=*/step};
523 
524   Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
525   Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
526 
527   Value newLowerBound =
528       isZeroBased ? lowerBound
529                   : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
530   Value newStep =
531       isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);
532 
533   // Insert code computing the value of the original loop induction variable
534   // from the "normalized" one.
535   Value scaled =
536       isStepOne
537           ? inductionVar
538           : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
539   Value shifted =
540       isZeroBased
541           ? scaled
542           : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
543 
544   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
545                                        shifted.getDefiningOp()};
546   inductionVar.replaceAllUsesExcept(shifted, preserve);
547   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
548           /*step=*/newStep};
549 }
550 
551 /// Transform a loop with a strictly positive step
552 ///   for %i = %lb to %ub step %s
553 /// into a 0-based loop with step 1
554 ///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
555 ///     %i = %ii * %s + %lb
556 /// Insert the induction variable remapping in the body of `inner`, which is
557 /// expected to be either `loop` or another loop perfectly nested under `loop`.
558 /// Insert the definition of new bounds immediate before `outer`, which is
559 /// expected to be either `loop` or its parent in the loop nest.
560 static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
561   OpBuilder builder(outer);
562   OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
563   auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
564                                   loop.getLowerBound(), loop.getUpperBound(),
565                                   loop.getStep(), loop.getInductionVar());
566 
567   loop.setLowerBound(loopPieces.lowerBound);
568   loop.setUpperBound(loopPieces.upperBound);
569   loop.setStep(loopPieces.step);
570 }
571 
void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  // Nothing to coalesce for fewer than two loops.
  if (loops.size() < 2)
    return;

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops)
    normalizeLoop(loop, outermost, innermost);

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder builder(outermost);
  Location loc = outermost.getLoc();
  Value upperBound = outermost.getUpperBound();
  for (auto loop : loops.drop_front())
    upperBound =
        builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
  outermost.setUpperBound(upperBound);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables. For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
  Value previous = outermost.getInductionVar();
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    // `idx` walks the loops from innermost (i == 0) to outermost.
    unsigned idx = loops.size() - i - 1;
    if (i != 0)
      previous = builder.create<arith::DivSIOp>(loc, previous,
                                                loops[idx + 1].getUpperBound());

    // The outermost loop (last in this walk) keeps the raw quotient; the
    // others take it modulo their own range.
    Value iv = (i == e - 1) ? previous
                            : builder.create<arith::RemSIOp>(
                                  loc, previous, loops[idx].getUpperBound());
    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
                               loops.back().getRegion());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  scf::ForOp second = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(second.getOperation()),
      innermost.getBody()->getOperations());
  second.erase();
}
627 
void mlir::collapseParallelLoops(
    scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder outsideBuilder(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    std::sort(dims.begin(), dims.end());

  // Normalize ParallelOp's iteration pattern: every dimension becomes
  // 0-based with step 1, so upper bounds equal iteration counts.
  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
      normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
    auto resultBounds =
        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
                      loops.getLowerBound()[i], loops.getUpperBound()[i],
                      loops.getStep()[i], loops.getBody()->getArgument(i));

    normalizedLowerBounds.push_back(resultBounds.lowerBound);
    normalizedUpperBounds.push_back(resultBounds.upperBound);
    normalizedSteps.push_back(resultBounds.step);
  }

  // Combine iteration spaces: each combined dimension iterates over the
  // product of the iteration counts of its constituent dimensions.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
    Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimensions[i]) {
      newUpperBound = outsideBuilder.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop: drop the old terminator, move the
  // old body's operations into the new body (before its terminator), and
  // erase the old ParallelOp.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
710 
711 // Hoist the ops within `outer` that appear before `inner`.
712 // Such ops include the ops that have been introduced by parametric tiling.
713 // Ops that come from triangular loops (i.e. that belong to the program slice
714 // rooted at `outer`) and ops that have side effects cannot be hoisted.
715 // Return failure when any op fails to hoist.
716 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
717   SetVector<Operation *> forwardSlice;
718   getForwardSlice(
719       outer.getInductionVar(), &forwardSlice,
720       [&inner](Operation *op) { return op != inner.getOperation(); });
721   LogicalResult status = success();
722   SmallVector<Operation *, 8> toHoist;
723   for (auto &op : outer.getBody()->without_terminator()) {
724     // Stop when encountering the inner loop.
725     if (&op == inner.getOperation())
726       break;
727     // Skip over non-hoistable ops.
728     if (forwardSlice.count(&op) > 0) {
729       status = failure();
730       continue;
731     }
732     // Skip intermediate scf::ForOp, these are not considered a failure.
733     if (isa<scf::ForOp>(op))
734       continue;
735     // Skip other ops with regions.
736     if (op.getNumRegions() > 0) {
737       status = failure();
738       continue;
739     }
740     // Skip if op has side effects.
741     // TODO: loads to immutable memory regions are ok.
742     if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
743       status = failure();
744       continue;
745     }
746     toHoist.push_back(&op);
747   }
748   auto *outerForOp = outer.getOperation();
749   for (auto *op : toHoist)
750     op->moveBefore(outerForOp);
751   return status;
752 }
753 
754 // Traverse the interTile and intraTile loops and try to hoist ops such that
755 // bands of perfectly nested loops are isolated.
756 // Return failure if either perfect interTile or perfect intraTile bands cannot
757 // be formed.
758 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
759   LogicalResult status = success();
760   const Loops &interTile = tileLoops.first;
761   const Loops &intraTile = tileLoops.second;
762   auto size = interTile.size();
763   assert(size == intraTile.size());
764   if (size <= 1)
765     return success();
766   for (unsigned s = 1; s < size; ++s)
767     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
768                                : failure();
769   for (unsigned s = 1; s < size; ++s)
770     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
771                                : failure();
772   return status;
773 }
774 
775 /// Collect perfectly nested loops starting from `rootForOps`.  Loops are
776 /// perfectly nested if each loop is the first and only non-terminator operation
777 /// in the parent loop.  Collect at most `maxLoops` loops and append them to
778 /// `forOps`.
779 template <typename T>
780 static void getPerfectlyNestedLoopsImpl(
781     SmallVectorImpl<T> &forOps, T rootForOp,
782     unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
783   for (unsigned i = 0; i < maxLoops; ++i) {
784     forOps.push_back(rootForOp);
785     Block &body = rootForOp.getRegion().front();
786     if (body.begin() != std::prev(body.end(), 2))
787       return;
788 
789     rootForOp = dyn_cast<T>(&body.front());
790     if (!rootForOp)
791       return;
792   }
793 }
794 
/// Strip-mines `forOp` by `factor`: the step of `forOp` is multiplied by
/// `factor`, and inside each loop in `targets` a new inner loop is created
/// that sweeps the current chunk with the original step. The body of each
/// target is moved into its new inner loop. Returns the newly created inner
/// loops, one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  OpBuilder b(forOp);
  // The outer loop now advances by `originalStep * factor` per iteration.
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done. Both must be
    // captured before new ops are inserted below.
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`. Note: this builder `b`
    // intentionally shadows the outer one.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Inner loop upper bound is min(outer upper bound, iv + new outer step)
    // so the last, possibly partial, chunk does not overrun the range.
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
                                         forOp.getUpperBound(), stepped);
    Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
                                         forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    // The half-open bound excludes the terminator of `t`, which stays put.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    // Inside the moved body, the outer induction variable is superseded by
    // the inner one.
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
830 
831 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
832 // Returns the new for operation, nested immediately under `target`.
833 template <typename SizeType>
834 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
835                                 scf::ForOp target) {
836   // TODO: Use cheap structural assertions that targets are nested under
837   // forOp and that targets are not nested under each other when DominanceInfo
838   // exposes the capability. It seems overkill to construct a whole function
839   // dominance tree at this point.
840   auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
841   assert(res.size() == 1 && "Expected 1 inner forOp");
842   return res[0];
843 }
844 
845 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
846                                  ArrayRef<Value> sizes,
847                                  ArrayRef<scf::ForOp> targets) {
848   SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
849   SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
850   for (auto it : llvm::zip(forOps, sizes)) {
851     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
852     res.push_back(step);
853     currentTargets = step;
854   }
855   return res;
856 }
857 
858 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
859                  scf::ForOp target) {
860   SmallVector<scf::ForOp, 8> res;
861   for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
862     assert(loops.size() == 1);
863     res.push_back(loops[0]);
864   }
865   return res;
866 }
867 
868 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
869   // Collect perfectly nested loops.  If more size values provided than nested
870   // loops available, truncate `sizes`.
871   SmallVector<scf::ForOp, 4> forOps;
872   forOps.reserve(sizes.size());
873   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
874   if (forOps.size() < sizes.size())
875     sizes = sizes.take_front(forOps.size());
876 
877   return ::tile(forOps, sizes, forOps.back());
878 }
879 
/// Collects the perfectly nested band of loops rooted at `root` (including
/// `root` itself) into `nestedLoops`, with no limit on the nesting depth.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
884 
885 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
886                                        ArrayRef<int64_t> sizes) {
887   // Collect perfectly nested loops.  If more size values provided than nested
888   // loops available, truncate `sizes`.
889   SmallVector<scf::ForOp, 4> forOps;
890   forOps.reserve(sizes.size());
891   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
892   if (forOps.size() < sizes.size())
893     sizes = sizes.take_front(forOps.size());
894 
895   // Compute the tile sizes such that i-th outer loop executes size[i]
896   // iterations.  Given that the loop current executes
897   //   numIterations = ceildiv((upperBound - lowerBound), step)
898   // iterations, we need to tile with size ceildiv(numIterations, size[i]).
899   SmallVector<Value, 4> tileSizes;
900   tileSizes.reserve(sizes.size());
901   for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
902     assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
903 
904     auto forOp = forOps[i];
905     OpBuilder builder(forOp);
906     auto loc = forOp.getLoc();
907     Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
908                                                forOp.getLowerBound());
909     Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
910     Value iterationsPerBlock =
911         ceilDivPositive(builder, loc, numIterations, sizes[i]);
912     tileSizes.push_back(iterationsPerBlock);
913   }
914 
915   // Call parametric tiling with the given sizes.
916   auto intraTile = tile(forOps, tileSizes, forOps.back());
917   TileLoops tileLoops = std::make_pair(forOps, intraTile);
918 
919   // TODO: for now we just ignore the result of band isolation.
920   // In the future, mapping decisions may be impacted by the ability to
921   // isolate perfectly nested bands.
922   (void)tryIsolateBands(tileLoops);
923 
924   return tileLoops;
925 }
926