1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 
27 using namespace mlir;
28 
29 namespace {
30 // This structure is to pass and return sets of loop parameters without
31 // confusing the order.
32 struct LoopParams {
33   Value lowerBound;
34   Value upperBound;
35   Value step;
36 };
37 } // namespace
38 
/// Create a replacement for `loop` that carries `newIterOperands` as extra
/// iter_args, invoking `newYieldValuesFn` to produce the values yielded for
/// those new iter_args. The original loop is emptied into a dead no-op loop
/// (the caller is expected to erase it) and the new loop is returned.
scf::ForOp
mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
                               ValueRange newIterOperands,
                               const NewYieldValueFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPoint(loop);
  auto operands = llvm::to_vector(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  // The empty body-builder callback leaves the new body without a terminator;
  // the original body (including its scf.yield) is spliced in below.
  scf::ForOp newLoop = builder.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands, [](OpBuilder &, Location, Value, ValueRange) {});

  Block *loopBody = loop.getBody();
  Block *newLoopBody = newLoop.getBody();

  // Move the body of the original loop to the new loop.
  newLoopBody->getOperations().splice(newLoopBody->end(),
                                      loopBody->getOperations());

  // Generate the new yield values to use by using the callback and append the
  // yield values to the scf.yield operation.
  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
  // The new iter_args are the trailing block arguments of the new body.
  ArrayRef<BlockArgument> newBBArgs =
      newLoopBody->getArguments().take_back(newIterOperands.size());
  {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPoint(yield);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
    assert(newIterOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    yield.getResultsMutable().append(newYieldedValues);
  }

  // Remap the BlockArguments from the original loop to the new loop
  // BlockArguments.
  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
  for (auto it :
       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));

  // Replace all uses of `newIterOperands` with the corresponding basic block
  // arguments — but only uses inside the new loop, so that the loop's own
  // init operands (outside the region) keep referring to the original values.
  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
      Operation *user = use.getOwner();
      return newLoop->isProperAncestor(user);
    });
  }

  // Replace all uses of the original loop with corresponding values from the
  // new loop.
  loop.replaceAllUsesWith(
      newLoop.getResults().take_front(loop.getNumResults()));

  // Add a fake yield to the original loop body that just returns the
  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
  // The loop is dead. The caller is expected to erase it.
  builder.setInsertionPointToEnd(loopBody);
  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());

  return newLoop;
}
103 
104 /// Outline a region with a single block into a new FuncOp.
105 /// Assumes the FuncOp result types is the type of the yielded operands of the
106 /// single block. This constraint makes it easy to determine the result.
107 /// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
109 /// provided, it will be set to point to the operation that calls the outlined
110 /// function.
111 // TODO: support more than single-block regions.
112 // TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  // Only single-block regions are supported (see TODO above).
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());

  // Values defined above the region become extra (trailing) func arguments.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // Result types are taken from the region terminator's operands, per the
  // contract stated in the function comment.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
                                    originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // Call operands: the reconstructed block arguments followed by captures.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    BlockAndValueMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // Index constants are cheap; re-materialize them inside the outlined
      // function instead of passing them as arguments' uses.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        BlockAndValueMapping bvm;
        repl = rewriter.clone(*cst, bvm)->getResult(0);
      }
    }
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}
216 
217 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
218                                 func::FuncOp *thenFn, StringRef thenFnName,
219                                 func::FuncOp *elseFn, StringRef elseFnName) {
220   IRRewriter rewriter(b);
221   Location loc = ifOp.getLoc();
222   FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223   if (thenFn && !ifOp.getThenRegion().empty()) {
224     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
225         rewriter, loc, ifOp.getThenRegion(), thenFnName);
226     if (failed(outlinedFuncOpOrFailure))
227       return failure();
228     *thenFn = *outlinedFuncOpOrFailure;
229   }
230   if (elseFn && !ifOp.getElseRegion().empty()) {
231     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
232         rewriter, loc, ifOp.getElseRegion(), elseFnName);
233     if (failed(outlinedFuncOpOrFailure))
234       return failure();
235     *elseFn = *outlinedFuncOpOrFailure;
236   }
237   return success();
238 }
239 
240 bool mlir::getInnermostParallelLoops(Operation *rootOp,
241                                      SmallVectorImpl<scf::ParallelOp> &result) {
242   assert(rootOp != nullptr && "Root operation must not be a nullptr.");
243   bool rootEnclosesPloops = false;
244   for (Region &region : rootOp->getRegions()) {
245     for (Block &block : region.getBlocks()) {
246       for (Operation &op : block) {
247         bool enclosesPloops = getInnermostParallelLoops(&op, result);
248         rootEnclosesPloops |= enclosesPloops;
249         if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
250           rootEnclosesPloops = true;
251 
252           // Collect parallel loop if it is an innermost one.
253           if (!enclosesPloops)
254             result.push_back(ploop);
255         }
256       }
257     }
258   }
259   return rootEnclosesPloops;
260 }
261 
262 // Build the IR that performs ceil division of a positive value by a constant:
263 //    ceildiv(a, B) = divis(a + (B-1), B)
264 // where divis is rounding-to-zero division.
265 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
266                              int64_t divisor) {
267   assert(divisor > 0 && "expected positive divisor");
268   assert(dividend.getType().isIndex() && "expected index-typed value");
269 
270   Value divisorMinusOneCst =
271       builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
272   Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
273   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
274   return builder.create<arith::DivSIOp>(loc, sum, divisorCst);
275 }
276 
277 // Build the IR that performs ceil division of a positive value by another
278 // positive value:
279 //    ceildiv(a, b) = divis(a + (b - 1), b)
280 // where divis is rounding-to-zero division.
281 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
282                              Value divisor) {
283   assert(dividend.getType().isIndex() && "expected index-typed value");
284 
285   Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
286   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
287   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
288   return builder.create<arith::DivSIOp>(loc, sum, divisor);
289 }
290 
291 /// Helper to replace uses of loop carried values (iter_args) and loop
292 /// yield values while promoting single iteration scf.for ops.
293 static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
294   // Replace uses of iter arguments with iter operands (initial values).
295   auto iterOperands = forOp.getIterOperands();
296   auto iterArgs = forOp.getRegionIterArgs();
297   for (auto e : llvm::zip(iterOperands, iterArgs))
298     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
299 
300   // Replace uses of loop results with the values yielded by the loop.
301   auto outerResults = forOp.getResults();
302   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
303   for (auto e : llvm::zip(outerResults, innerResults))
304     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
305 }
306 
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
  // The trip count can only be computed when all three bounds are constant.
  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  // NOTE(review): these checks reject negative values but admit step == 0,
  // which would divide by zero in ceilDiv below — presumably callers
  // guarantee a strictly positive step; verify.
  if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
      ubCstOp.value() < 0 || stepCstOp.value() < 0)
    return failure();
  int64_t tripCount =
      mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
  if (tripCount != 1)
    return failure();
  // With one iteration, the induction variable equals the lower bound.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(lbCstOp);

  replaceIterArgsAndYieldResults(forOp);

  // Move the loop body operations, except for its terminator, to the loop's
  // containing block. The terminator must be erased before the splice so it
  // does not end up in the parent block.
  auto *parentBlock = forOp->getBlock();
  forOp.getBody()->getTerminator()->erase();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
334 
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
/// unrolled iteration using annotateFn. 'iterArgs'/'yieldedValues' thread the
/// loop-carried values through the unrolled copies.
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Builder to insert unrolled bodies just before the terminator of the body of
  // 'forOp'.
  auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  // Default to a no-op annotation so the loop below needs no null checks.
  if (!annotateFn)
    annotateFn = [](unsigned, Operation *, OpBuilder) {};

  // Keep a pointer to the last non-terminator operation in the original block
  // so that we know what to clone (since we are doing this in-place).
  Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);

  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
  SmallVector<Value, 4> lastYielded(yieldedValues);

  for (unsigned i = 1; i < unrollFactor; i++) {
    BlockAndValueMapping operandMap;

    // Prepare operand map: iter_args of copy `i` are the values yielded by
    // copy `i - 1`.
    operandMap.map(iterArgs, lastYielded);

    // If the induction variable is used, create a remapping to the value for
    // this unrolled instance.
    if (!forOpIV.use_empty()) {
      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
      operandMap.map(forOpIV, ivUnroll);
    }

    // Clone the original body of 'forOp'. `srcBlockEnd` was captured before
    // any clones were appended, so only the original ops are copied.
    for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
      Operation *clonedOp = builder.clone(*it, operandMap);
      annotateFn(i, clonedOp, builder);
    }

    // Update yielded values.
    for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
      lastYielded[i] = operandMap.lookup(yieldedValues[i]);
  }

  // Make sure we annotate the Ops in the original body. We do this last so that
  // any annotations are not copied into the cloned Ops above.
  for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
    annotateFn(0, &*it, builder);

  // Update operands of the yield statement.
  loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
390 
/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
LogicalResult mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  OpBuilder boundsBuilder(forOp);
  auto loc = forOp.getLoc();
  auto step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
  if (lbCstOp && ubCstOp && stepCstOp) {
    // Constant loop bounds computation.
    int64_t lbCst = lbCstOp.value();
    int64_t ubCst = ubCstOp.value();
    int64_t stepCst = stepCstOp.value();
    // NOTE(review): message says "positive" but the check admits zero;
    // presumably a zero step never reaches here — verify.
    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
           "expected positive loop bounds and step");
    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);

    // Factor 1 means no unrolling; only attempt single-iteration promotion.
    if (unrollFactor == 1) {
      if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
        return failure();
      return success();
    }

    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    assert(upperBoundUnrolledCst <= ubCst);
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    // An epilogue is only needed when the trip count is not an even multiple
    // of the unroll factor.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
          loc, upperBoundUnrolledCst);
    else
      upperBoundUnrolled = ubCstOp;

    // Create constant for 'stepUnrolled'.
    stepUnrolled = stepCst == stepUnrolledCst
                       ? step
                       : boundsBuilder.create<arith::ConstantIndexOp>(
                             loc, stepUnrolledCst);
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst =
        boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
    Value tripCountRem =
        boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
        loc, lowerBound,
        boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
  }

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  // The epilogue is a clone of the original loop that runs the leftover
  // iterations [upperBoundUnrolled, ub) with the original step.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPoint(forOp->getBlock(),
                                      std::next(Block::iterator(forOp)));
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    // Feed the unrolled loop's results into the epilogue's iter operands.
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getNumIterOperands(), results);
    (void)promoteIfSingleIteration(epilogueForOp);
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        auto stride = b.create<arith::MulIOp>(
            loc, step, b.create<arith::ConstantIndexOp>(loc, i));
        return b.create<arith::AddIOp>(loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  (void)promoteIfSingleIteration(forOp);
  return success();
}
511 
512 /// Return the new lower bound, upper bound, and step in that order. Insert any
513 /// additional bounds calculations before the given builder and any additional
514 /// conversion back to the original loop induction value inside the given Block.
515 static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
516                                 OpBuilder &insideLoopBuilder, Location loc,
517                                 Value lowerBound, Value upperBound, Value step,
518                                 Value inductionVar) {
519   // Check if the loop is already known to have a constant zero lower bound or
520   // a constant one step.
521   bool isZeroBased = false;
522   if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
523     isZeroBased = ubCst.value() == 0;
524 
525   bool isStepOne = false;
526   if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
527     isStepOne = stepCst.value() == 1;
528 
529   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
530   // assuming the step is strictly positive.  Update the bounds and the step
531   // of the loop to go from 0 to the number of iterations, if necessary.
532   // TODO: introduce support for negative steps or emit dynamic asserts
533   // on step positivity, whatever gets implemented first.
534   if (isZeroBased && isStepOne)
535     return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
536             /*step=*/step};
537 
538   Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
539   Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
540 
541   Value newLowerBound =
542       isZeroBased ? lowerBound
543                   : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
544   Value newStep =
545       isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);
546 
547   // Insert code computing the value of the original loop induction variable
548   // from the "normalized" one.
549   Value scaled =
550       isStepOne
551           ? inductionVar
552           : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
553   Value shifted =
554       isZeroBased
555           ? scaled
556           : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
557 
558   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
559                                        shifted.getDefiningOp()};
560   inductionVar.replaceAllUsesExcept(shifted, preserve);
561   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
562           /*step=*/newStep};
563 }
564 
565 /// Transform a loop with a strictly positive step
566 ///   for %i = %lb to %ub step %s
567 /// into a 0-based loop with step 1
568 ///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
569 ///     %i = %ii * %s + %lb
570 /// Insert the induction variable remapping in the body of `inner`, which is
571 /// expected to be either `loop` or another loop perfectly nested under `loop`.
572 /// Insert the definition of new bounds immediate before `outer`, which is
573 /// expected to be either `loop` or its parent in the loop nest.
574 static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
575   OpBuilder builder(outer);
576   OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
577   auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
578                                   loop.getLowerBound(), loop.getUpperBound(),
579                                   loop.getStep(), loop.getInductionVar());
580 
581   loop.setLowerBound(loopPieces.lowerBound);
582   loop.setUpperBound(loopPieces.upperBound);
583   loop.setStep(loopPieces.step);
584 }
585 
void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  // Nothing to coalesce for zero or one loop.
  if (loops.size() < 2)
    return;

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops)
    normalizeLoop(loop, outermost, innermost);

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder builder(outermost);
  Location loc = outermost.getLoc();
  Value upperBound = outermost.getUpperBound();
  for (auto loop : loops.drop_front())
    upperBound =
        builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
  outermost.setUpperBound(upperBound);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables. For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
  Value previous = outermost.getInductionVar();
  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    // Iterate from the innermost (idx == loops.size() - 1) outward.
    unsigned idx = loops.size() - i - 1;
    if (i != 0)
      previous = builder.create<arith::DivSIOp>(loc, previous,
                                                loops[idx + 1].getUpperBound());

    // The outermost loop (i == e - 1) takes the remaining quotient directly.
    Value iv = (i == e - 1) ? previous
                            : builder.create<arith::RemSIOp>(
                                  loc, previous, loops[idx].getUpperBound());
    replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
                               loops.back().getRegion());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  scf::ForOp second = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(second.getOperation()),
      innermost.getBody()->getOperations());
  second.erase();
}
641 
void mlir::collapseParallelLoops(
    scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder outsideBuilder(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    std::sort(dims.begin(), dims.end());

  // Normalize ParallelOp's iteration pattern. After this, every dimension
  // runs from 0 with step 1, so its upper bound equals its iteration count.
  SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
      normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
    auto resultBounds =
        normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
                      loops.getLowerBound()[i], loops.getUpperBound()[i],
                      loops.getStep()[i], loops.getBody()->getArgument(i));

    normalizedLowerBounds.push_back(resultBounds.lowerBound);
    normalizedUpperBounds.push_back(resultBounds.upperBound);
    normalizedSteps.push_back(resultBounds.step);
  }

  // Combine iteration spaces: each collapsed dimension runs over the product
  // of the iteration counts of the dimensions it merges.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
  auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
  for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
    Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
    for (auto idx : sortedDimensions[i]) {
      newUpperBound = outsideBuilder.create<arith::MulIOp>(
          loc, newUpperBound, normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = outsideBuilder.create<scf::ParallelOp>(
      loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = insideBuilder.create<arith::RemSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = insideBuilder.create<arith::DivSIOp>(
                loc, previous, normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop: drop the old terminator, splice
  // the old body before the new loop's terminator, and erase the old op.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}
724 
725 // Hoist the ops within `outer` that appear before `inner`.
726 // Such ops include the ops that have been introduced by parametric tiling.
727 // Ops that come from triangular loops (i.e. that belong to the program slice
728 // rooted at `outer`) and ops that have side effects cannot be hoisted.
729 // Return failure when any op fails to hoist.
730 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
731   SetVector<Operation *> forwardSlice;
732   getForwardSlice(
733       outer.getInductionVar(), &forwardSlice,
734       [&inner](Operation *op) { return op != inner.getOperation(); });
735   LogicalResult status = success();
736   SmallVector<Operation *, 8> toHoist;
737   for (auto &op : outer.getBody()->without_terminator()) {
738     // Stop when encountering the inner loop.
739     if (&op == inner.getOperation())
740       break;
741     // Skip over non-hoistable ops.
742     if (forwardSlice.count(&op) > 0) {
743       status = failure();
744       continue;
745     }
746     // Skip intermediate scf::ForOp, these are not considered a failure.
747     if (isa<scf::ForOp>(op))
748       continue;
749     // Skip other ops with regions.
750     if (op.getNumRegions() > 0) {
751       status = failure();
752       continue;
753     }
754     // Skip if op has side effects.
755     // TODO: loads to immutable memory regions are ok.
756     if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
757       status = failure();
758       continue;
759     }
760     toHoist.push_back(&op);
761   }
762   auto *outerForOp = outer.getOperation();
763   for (auto *op : toHoist)
764     op->moveBefore(outerForOp);
765   return status;
766 }
767 
768 // Traverse the interTile and intraTile loops and try to hoist ops such that
769 // bands of perfectly nested loops are isolated.
770 // Return failure if either perfect interTile or perfect intraTile bands cannot
771 // be formed.
772 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
773   LogicalResult status = success();
774   const Loops &interTile = tileLoops.first;
775   const Loops &intraTile = tileLoops.second;
776   auto size = interTile.size();
777   assert(size == intraTile.size());
778   if (size <= 1)
779     return success();
780   for (unsigned s = 1; s < size; ++s)
781     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
782                                : failure();
783   for (unsigned s = 1; s < size; ++s)
784     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
785                                : failure();
786   return status;
787 }
788 
/// Collect perfectly nested loops starting from `rootForOp`.  Loops are
/// perfectly nested if each loop is the first and only non-terminator operation
/// in the parent loop.  Collect at most `maxLoops` loops and append them to
/// `forOps`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  for (unsigned i = 0; i < maxLoops; ++i) {
    // The current loop always belongs to the band; descend only if its body
    // consists of exactly one op plus the terminator.
    forOps.push_back(rootForOp);
    Block &body = rootForOp.getRegion().front();
    // `begin() == prev(end(), 2)` holds iff the body has exactly two ops
    // (one op followed by the terminator); otherwise the nest is not perfect.
    if (body.begin() != std::prev(body.end(), 2))
      return;

    // Continue only when that single op is itself a loop of the same type.
    rootForOp = dyn_cast<T>(&body.front());
    if (!rootForOp)
      return;
  }
}
808 
/// Strip-mine `forOp` by `factor`: multiply its step by `factor` and, inside
/// each loop in `targets`, create an inner loop that covers one original-step
/// chunk of the stripped iteration space. The bodies of the `targets` are
/// moved into the new inner loops. Returns the newly created inner loops,
/// one per target.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  auto originalStep = forOp.getStep();
  auto iv = forOp.getInductionVar();

  // The outer loop now advances `factor` original steps at a time.
  OpBuilder b(forOp);
  forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));

  Loops innerLoops;
  for (auto t : targets) {
    // Save information for splicing ops out of t when done. The op count is
    // captured *before* new ops are inserted below, so the splice range
    // covers exactly the original body (minus the terminator).
    auto begin = t.getBody()->begin();
    auto nOps = t.getBody()->getOperations().size();

    // Insert newForOp before the terminator of `t`.
    // This builder shadows the outer `b` on purpose: all new ops go at the
    // end of the target's body.
    auto b = OpBuilder::atBlockTerminator((t.getBody()));
    // Inner upper bound is min(outer upper bound, iv + new outer step), so
    // the last (possibly partial) chunk does not overrun the iteration space.
    Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
    Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
                                         forOp.getUpperBound(), stepped);
    Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
                                         forOp.getUpperBound(), stepped);

    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
    // The excluded last element of the range is the original terminator,
    // which must stay in `t`.
    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
    newForOp.getBody()->getOperations().splice(
        newForOp.getBody()->getOperations().begin(),
        t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
    // The spliced ops referred to the outer iv; inside the new loop they must
    // use the inner iv instead.
    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
                               newForOp.getRegion());

    innerLoops.push_back(newForOp);
  }

  return innerLoops;
}
844 
845 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
846 // Returns the new for operation, nested immediately under `target`.
847 template <typename SizeType>
848 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
849                                 scf::ForOp target) {
850   // TODO: Use cheap structural assertions that targets are nested under
851   // forOp and that targets are not nested under each other when DominanceInfo
852   // exposes the capability. It seems overkill to construct a whole function
853   // dominance tree at this point.
854   auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
855   assert(res.size() == 1 && "Expected 1 inner forOp");
856   return res[0];
857 }
858 
859 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
860                                  ArrayRef<Value> sizes,
861                                  ArrayRef<scf::ForOp> targets) {
862   SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
863   SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
864   for (auto it : llvm::zip(forOps, sizes)) {
865     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
866     res.push_back(step);
867     currentTargets = step;
868   }
869   return res;
870 }
871 
872 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
873                  scf::ForOp target) {
874   SmallVector<scf::ForOp, 8> res;
875   for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
876     assert(loops.size() == 1);
877     res.push_back(loops[0]);
878   }
879   return res;
880 }
881 
882 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
883   // Collect perfectly nested loops.  If more size values provided than nested
884   // loops available, truncate `sizes`.
885   SmallVector<scf::ForOp, 4> forOps;
886   forOps.reserve(sizes.size());
887   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
888   if (forOps.size() < sizes.size())
889     sizes = sizes.take_front(forOps.size());
890 
891   return ::tile(forOps, sizes, forOps.back());
892 }
893 
/// Public entry point: collect the full perfectly nested band rooted at
/// `root` into `nestedLoops` (no depth limit — the impl's default maxLoops
/// applies).
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl(nestedLoops, root);
}
898 
899 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
900                                        ArrayRef<int64_t> sizes) {
901   // Collect perfectly nested loops.  If more size values provided than nested
902   // loops available, truncate `sizes`.
903   SmallVector<scf::ForOp, 4> forOps;
904   forOps.reserve(sizes.size());
905   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
906   if (forOps.size() < sizes.size())
907     sizes = sizes.take_front(forOps.size());
908 
909   // Compute the tile sizes such that i-th outer loop executes size[i]
910   // iterations.  Given that the loop current executes
911   //   numIterations = ceildiv((upperBound - lowerBound), step)
912   // iterations, we need to tile with size ceildiv(numIterations, size[i]).
913   SmallVector<Value, 4> tileSizes;
914   tileSizes.reserve(sizes.size());
915   for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
916     assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
917 
918     auto forOp = forOps[i];
919     OpBuilder builder(forOp);
920     auto loc = forOp.getLoc();
921     Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
922                                                forOp.getLowerBound());
923     Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
924     Value iterationsPerBlock =
925         ceilDivPositive(builder, loc, numIterations, sizes[i]);
926     tileSizes.push_back(iterationsPerBlock);
927   }
928 
929   // Call parametric tiling with the given sizes.
930   auto intraTile = tile(forOps, tileSizes, forOps.back());
931   TileLoops tileLoops = std::make_pair(forOps, intraTile);
932 
933   // TODO: for now we just ignore the result of band isolation.
934   // In the future, mapping decisions may be impacted by the ability to
935   // isolate perfectly nested bands.
936   (void)tryIsolateBands(tileLoops);
937 
938   return tileLoops;
939 }
940