1 //===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Utils/Utils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "mlir/Transforms/RegionUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 
27 using namespace mlir;
28 
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order. It is used as the return type of the `normalizeLoop`
// helper in this file, which fills in the bounds and step of the normalized
// loop.
struct LoopParams {
  // Value to use as the loop lower bound.
  Value lowerBound;
  // Value to use as the loop upper bound.
  Value upperBound;
  // Value to use as the loop step.
  Value step;
};
} // namespace
38 
39 scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
40                                           ValueRange newIterOperands,
41                                           NewYieldValueFn newYieldValuesFn) {
42   // Create a new loop before the existing one, with the extra operands.
43   OpBuilder::InsertionGuard g(builder);
44   builder.setInsertionPoint(loop);
45   auto operands = llvm::to_vector(loop.getIterOperands());
46   operands.append(newIterOperands.begin(), newIterOperands.end());
47   scf::ForOp newLoop = builder.create<scf::ForOp>(
48       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
49       operands, [](OpBuilder &, Location, Value, ValueRange) {});
50 
51   Block *loopBody = loop.getBody();
52   Block *newLoopBody = newLoop.getBody();
53 
54   // Move the body of the original loop to the new loop.
55   newLoopBody->getOperations().splice(newLoopBody->end(),
56                                       loopBody->getOperations());
57 
58   // Generate the new yield values to use by using the callback and append the
59   // yield values to the scf.yield operation.
60   auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
61   ArrayRef<BlockArgument> newBBArgs =
62       newLoopBody->getArguments().take_back(newIterOperands.size());
63   {
64     OpBuilder::InsertionGuard g(builder);
65     builder.setInsertionPoint(yield);
66     SmallVector<Value> newYieldedValues =
67         newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
68     assert(newIterOperands.size() == newYieldedValues.size() &&
69            "expected as many new yield values as new iter operands");
70     yield.getResultsMutable().append(newYieldedValues);
71   }
72 
73   // Remap the BlockArguments from the original loop to the new loop
74   // BlockArguments.
75   ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
76   for (auto it :
77        llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
78     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
79 
80   // Replace all uses of `newIterOperands` with the corresponding basic block
81   // arguments.
82   for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
83     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
84       Operation *user = use.getOwner();
85       return newLoop->isProperAncestor(user);
86     });
87   }
88 
89   // Replace all uses of the original loop with corresponding values from the
90   // new loop.
91   loop.replaceAllUsesWith(
92       newLoop.getResults().take_front(loop.getNumResults()));
93 
94   // Add a fake yield to the original loop body that just returns the
95   // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
96   // The loop is dead. The caller is expected to erase it.
97   builder.setInsertionPointToEnd(loopBody);
98   builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
99 
100   return newLoop;
101 }
102 
103 /// Outline a region with a single block into a new FuncOp.
104 /// Assumes the FuncOp result types is the type of the yielded operands of the
105 /// single block. This constraint makes it easy to determine the result.
106 /// This method also clones the `arith::ConstantIndexOp` at the start of
107 /// `outlinedFuncBody` to alloc simple canonicalizations.
108 // TODO: support more than single-block regions.
109 // TODO: more flexible constant handling.
110 FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
111                                                        Location loc,
112                                                        Region &region,
113                                                        StringRef funcName) {
114   assert(!funcName.empty() && "funcName cannot be empty");
115   if (!region.hasOneBlock())
116     return failure();
117 
118   Block *originalBlock = &region.front();
119   Operation *originalTerminator = originalBlock->getTerminator();
120 
121   // Outline before current function.
122   OpBuilder::InsertionGuard g(rewriter);
123   rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());
124 
125   SetVector<Value> captures;
126   getUsedValuesDefinedAbove(region, captures);
127 
128   ValueRange outlinedValues(captures.getArrayRef());
129   SmallVector<Type> outlinedFuncArgTypes;
130   SmallVector<Location> outlinedFuncArgLocs;
131   // Region's arguments are exactly the first block's arguments as per
132   // Region::getArguments().
133   // Func's arguments are cat(regions's arguments, captures arguments).
134   for (BlockArgument arg : region.getArguments()) {
135     outlinedFuncArgTypes.push_back(arg.getType());
136     outlinedFuncArgLocs.push_back(arg.getLoc());
137   }
138   for (Value value : outlinedValues) {
139     outlinedFuncArgTypes.push_back(value.getType());
140     outlinedFuncArgLocs.push_back(value.getLoc());
141   }
142   FunctionType outlinedFuncType =
143       FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
144                         originalTerminator->getOperandTypes());
145   auto outlinedFunc =
146       rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType);
147   Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
148 
149   // Merge blocks while replacing the original block operands.
150   // Warning: `mergeBlocks` erases the original block, reconstruct it later.
151   int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
152   auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
153   {
154     OpBuilder::InsertionGuard g(rewriter);
155     rewriter.setInsertionPointToEnd(outlinedFuncBody);
156     rewriter.mergeBlocks(
157         originalBlock, outlinedFuncBody,
158         outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
159     // Explicitly set up a new ReturnOp terminator.
160     rewriter.setInsertionPointToEnd(outlinedFuncBody);
161     rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(),
162                                     originalTerminator->getOperands());
163   }
164 
165   // Reconstruct the block that was deleted and add a
166   // terminator(call_results).
167   Block *newBlock = rewriter.createBlock(
168       &region, region.begin(),
169       TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
170       ArrayRef<Location>(outlinedFuncArgLocs)
171           .take_front(numOriginalBlockArguments));
172   {
173     OpBuilder::InsertionGuard g(rewriter);
174     rewriter.setInsertionPointToEnd(newBlock);
175     SmallVector<Value> callValues;
176     llvm::append_range(callValues, newBlock->getArguments());
177     llvm::append_range(callValues, outlinedValues);
178     Operation *call =
179         rewriter.create<func::CallOp>(loc, outlinedFunc, callValues);
180 
181     // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
182     // Clone `originalTerminator` to take the callOp results then erase it from
183     // `outlinedFuncBody`.
184     BlockAndValueMapping bvm;
185     bvm.map(originalTerminator->getOperands(), call->getResults());
186     rewriter.clone(*originalTerminator, bvm);
187     rewriter.eraseOp(originalTerminator);
188   }
189 
190   // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
191   // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
192   for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
193                                                outlinedValues.size()))) {
194     Value orig = std::get<0>(it);
195     Value repl = std::get<1>(it);
196     {
197       OpBuilder::InsertionGuard g(rewriter);
198       rewriter.setInsertionPointToStart(outlinedFuncBody);
199       if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
200         BlockAndValueMapping bvm;
201         repl = rewriter.clone(*cst, bvm)->getResult(0);
202       }
203     }
204     orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
205       return outlinedFunc->isProperAncestor(opOperand.getOwner());
206     });
207   }
208 
209   return outlinedFunc;
210 }
211 
212 LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
213                                 func::FuncOp *thenFn, StringRef thenFnName,
214                                 func::FuncOp *elseFn, StringRef elseFnName) {
215   IRRewriter rewriter(b);
216   Location loc = ifOp.getLoc();
217   FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
218   if (thenFn && !ifOp.getThenRegion().empty()) {
219     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
220         rewriter, loc, ifOp.getThenRegion(), thenFnName);
221     if (failed(outlinedFuncOpOrFailure))
222       return failure();
223     *thenFn = *outlinedFuncOpOrFailure;
224   }
225   if (elseFn && !ifOp.getElseRegion().empty()) {
226     outlinedFuncOpOrFailure = outlineSingleBlockRegion(
227         rewriter, loc, ifOp.getElseRegion(), elseFnName);
228     if (failed(outlinedFuncOpOrFailure))
229       return failure();
230     *elseFn = *outlinedFuncOpOrFailure;
231   }
232   return success();
233 }
234 
235 bool mlir::getInnermostParallelLoops(Operation *rootOp,
236                                      SmallVectorImpl<scf::ParallelOp> &result) {
237   assert(rootOp != nullptr && "Root operation must not be a nullptr.");
238   bool rootEnclosesPloops = false;
239   for (Region &region : rootOp->getRegions()) {
240     for (Block &block : region.getBlocks()) {
241       for (Operation &op : block) {
242         bool enclosesPloops = getInnermostParallelLoops(&op, result);
243         rootEnclosesPloops |= enclosesPloops;
244         if (auto ploop = dyn_cast<scf::ParallelOp>(op)) {
245           rootEnclosesPloops = true;
246 
247           // Collect parallel loop if it is an innermost one.
248           if (!enclosesPloops)
249             result.push_back(ploop);
250         }
251       }
252     }
253   }
254   return rootEnclosesPloops;
255 }
256 
257 // Build the IR that performs ceil division of a positive value by a constant:
258 //    ceildiv(a, B) = divis(a + (B-1), B)
259 // where divis is rounding-to-zero division.
260 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
261                              int64_t divisor) {
262   assert(divisor > 0 && "expected positive divisor");
263   assert(dividend.getType().isIndex() && "expected index-typed value");
264 
265   Value divisorMinusOneCst =
266       builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
267   Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
268   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
269   return builder.create<arith::DivSIOp>(loc, sum, divisorCst);
270 }
271 
272 // Build the IR that performs ceil division of a positive value by another
273 // positive value:
274 //    ceildiv(a, b) = divis(a + (b - 1), b)
275 // where divis is rounding-to-zero division.
276 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
277                              Value divisor) {
278   assert(dividend.getType().isIndex() && "expected index-typed value");
279 
280   Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
281   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
282   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
283   return builder.create<arith::DivSIOp>(loc, sum, divisor);
284 }
285 
286 /// Helper to replace uses of loop carried values (iter_args) and loop
287 /// yield values while promoting single iteration scf.for ops.
288 static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
289   // Replace uses of iter arguments with iter operands (initial values).
290   auto iterOperands = forOp.getIterOperands();
291   auto iterArgs = forOp.getRegionIterArgs();
292   for (auto e : llvm::zip(iterOperands, iterArgs))
293     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
294 
295   // Replace uses of loop results with the values yielded by the loop.
296   auto outerResults = forOp.getResults();
297   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
298   for (auto e : llvm::zip(outerResults, innerResults))
299     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
300 }
301 
302 /// Promotes the loop body of a forOp to its containing block if the forOp
303 /// it can be determined that the loop has a single iteration.
304 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
305   auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
306   auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
307   auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
308   if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
309       ubCstOp.value() < 0 || stepCstOp.value() < 0)
310     return failure();
311   int64_t tripCount =
312       mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
313   if (tripCount != 1)
314     return failure();
315   auto iv = forOp.getInductionVar();
316   iv.replaceAllUsesWith(lbCstOp);
317 
318   replaceIterArgsAndYieldResults(forOp);
319 
320   // Move the loop body operations, except for its terminator, to the loop's
321   // containing block.
322   auto *parentBlock = forOp->getBlock();
323   forOp.getBody()->getTerminator()->erase();
324   parentBlock->getOperations().splice(Block::iterator(forOp),
325                                       forOp.getBody()->getOperations());
326   forOp.erase();
327   return success();
328 }
329 
330 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
331 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
332 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
333 /// unrolled iteration using annotateFn.
334 static void generateUnrolledLoop(
335     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
336     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
337     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
338     ValueRange iterArgs, ValueRange yieldedValues) {
339   // Builder to insert unrolled bodies just before the terminator of the body of
340   // 'forOp'.
341   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
342 
343   if (!annotateFn)
344     annotateFn = [](unsigned, Operation *, OpBuilder) {};
345 
346   // Keep a pointer to the last non-terminator operation in the original block
347   // so that we know what to clone (since we are doing this in-place).
348   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
349 
350   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
351   SmallVector<Value, 4> lastYielded(yieldedValues);
352 
353   for (unsigned i = 1; i < unrollFactor; i++) {
354     BlockAndValueMapping operandMap;
355 
356     // Prepare operand map.
357     operandMap.map(iterArgs, lastYielded);
358 
359     // If the induction variable is used, create a remapping to the value for
360     // this unrolled instance.
361     if (!forOpIV.use_empty()) {
362       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
363       operandMap.map(forOpIV, ivUnroll);
364     }
365 
366     // Clone the original body of 'forOp'.
367     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
368       Operation *clonedOp = builder.clone(*it, operandMap);
369       annotateFn(i, clonedOp, builder);
370     }
371 
372     // Update yielded values.
373     for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
374       lastYielded[i] = operandMap.lookup(yieldedValues[i]);
375   }
376 
377   // Make sure we annotate the Ops in the original body. We do this last so that
378   // any annotations are not copied into the cloned Ops above.
379   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
380     annotateFn(0, &*it, builder);
381 
382   // Update operands of the yield statement.
383   loopBodyBlock->getTerminator()->setOperands(lastYielded);
384 }
385 
386 /// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
387 LogicalResult mlir::loopUnrollByFactor(
388     scf::ForOp forOp, uint64_t unrollFactor,
389     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
390   assert(unrollFactor > 0 && "expected positive unroll factor");
391 
392   // Return if the loop body is empty.
393   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
394     return success();
395 
396   // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
397   // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
398   OpBuilder boundsBuilder(forOp);
399   auto loc = forOp.getLoc();
400   auto step = forOp.getStep();
401   Value upperBoundUnrolled;
402   Value stepUnrolled;
403   bool generateEpilogueLoop = true;
404 
405   auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
406   auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
407   auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
408   if (lbCstOp && ubCstOp && stepCstOp) {
409     // Constant loop bounds computation.
410     int64_t lbCst = lbCstOp.value();
411     int64_t ubCst = ubCstOp.value();
412     int64_t stepCst = stepCstOp.value();
413     assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
414            "expected positive loop bounds and step");
415     int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
416 
417     if (unrollFactor == 1) {
418       if (tripCount == 1 && failed(promoteIfSingleIteration(forOp)))
419         return failure();
420       return success();
421     }
422 
423     int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
424     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
425     assert(upperBoundUnrolledCst <= ubCst);
426     int64_t stepUnrolledCst = stepCst * unrollFactor;
427 
428     // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
429     generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
430     if (generateEpilogueLoop)
431       upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
432           loc, upperBoundUnrolledCst);
433     else
434       upperBoundUnrolled = ubCstOp;
435 
436     // Create constant for 'stepUnrolled'.
437     stepUnrolled = stepCst == stepUnrolledCst
438                        ? step
439                        : boundsBuilder.create<arith::ConstantIndexOp>(
440                              loc, stepUnrolledCst);
441   } else {
442     // Dynamic loop bounds computation.
443     // TODO: Add dynamic asserts for negative lb/ub/step, or
444     // consider using ceilDiv from AffineApplyExpander.
445     auto lowerBound = forOp.getLowerBound();
446     auto upperBound = forOp.getUpperBound();
447     Value diff =
448         boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
449     Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
450     Value unrollFactorCst =
451         boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
452     Value tripCountRem =
453         boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
454     // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
455     Value tripCountEvenMultiple =
456         boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem);
457     // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
458     upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>(
459         loc, lowerBound,
460         boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
461     // Scale 'step' by 'unrollFactor'.
462     stepUnrolled =
463         boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
464   }
465 
466   // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
467   if (generateEpilogueLoop) {
468     OpBuilder epilogueBuilder(forOp->getContext());
469     epilogueBuilder.setInsertionPoint(forOp->getBlock(),
470                                       std::next(Block::iterator(forOp)));
471     auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
472     epilogueForOp.setLowerBound(upperBoundUnrolled);
473 
474     // Update uses of loop results.
475     auto results = forOp.getResults();
476     auto epilogueResults = epilogueForOp.getResults();
477 
478     for (auto e : llvm::zip(results, epilogueResults)) {
479       std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
480     }
481     epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
482                                epilogueForOp.getNumIterOperands(), results);
483     (void)promoteIfSingleIteration(epilogueForOp);
484   }
485 
486   // Create unrolled loop.
487   forOp.setUpperBound(upperBoundUnrolled);
488   forOp.setStep(stepUnrolled);
489 
490   auto iterArgs = ValueRange(forOp.getRegionIterArgs());
491   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
492 
493   generateUnrolledLoop(
494       forOp.getBody(), forOp.getInductionVar(), unrollFactor,
495       [&](unsigned i, Value iv, OpBuilder b) {
496         // iv' = iv + step * i;
497         auto stride = b.create<arith::MulIOp>(
498             loc, step, b.create<arith::ConstantIndexOp>(loc, i));
499         return b.create<arith::AddIOp>(loc, iv, stride);
500       },
501       annotateFn, iterArgs, yieldedValues);
502   // Promote the loop body up if this has turned into a single iteration loop.
503   (void)promoteIfSingleIteration(forOp);
504   return success();
505 }
506 
507 /// Return the new lower bound, upper bound, and step in that order. Insert any
508 /// additional bounds calculations before the given builder and any additional
509 /// conversion back to the original loop induction value inside the given Block.
510 static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
511                                 OpBuilder &insideLoopBuilder, Location loc,
512                                 Value lowerBound, Value upperBound, Value step,
513                                 Value inductionVar) {
514   // Check if the loop is already known to have a constant zero lower bound or
515   // a constant one step.
516   bool isZeroBased = false;
517   if (auto ubCst = lowerBound.getDefiningOp<arith::ConstantIndexOp>())
518     isZeroBased = ubCst.value() == 0;
519 
520   bool isStepOne = false;
521   if (auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>())
522     isStepOne = stepCst.value() == 1;
523 
524   // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
525   // assuming the step is strictly positive.  Update the bounds and the step
526   // of the loop to go from 0 to the number of iterations, if necessary.
527   // TODO: introduce support for negative steps or emit dynamic asserts
528   // on step positivity, whatever gets implemented first.
529   if (isZeroBased && isStepOne)
530     return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
531             /*step=*/step};
532 
533   Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
534   Value newUpperBound = ceilDivPositive(boundsBuilder, loc, diff, step);
535 
536   Value newLowerBound =
537       isZeroBased ? lowerBound
538                   : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0);
539   Value newStep =
540       isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1);
541 
542   // Insert code computing the value of the original loop induction variable
543   // from the "normalized" one.
544   Value scaled =
545       isStepOne
546           ? inductionVar
547           : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
548   Value shifted =
549       isZeroBased
550           ? scaled
551           : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
552 
553   SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
554                                        shifted.getDefiningOp()};
555   inductionVar.replaceAllUsesExcept(shifted, preserve);
556   return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
557           /*step=*/newStep};
558 }
559 
560 /// Transform a loop with a strictly positive step
561 ///   for %i = %lb to %ub step %s
562 /// into a 0-based loop with step 1
563 ///   for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
564 ///     %i = %ii * %s + %lb
565 /// Insert the induction variable remapping in the body of `inner`, which is
566 /// expected to be either `loop` or another loop perfectly nested under `loop`.
567 /// Insert the definition of new bounds immediate before `outer`, which is
568 /// expected to be either `loop` or its parent in the loop nest.
569 static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
570   OpBuilder builder(outer);
571   OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
572   auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
573                                   loop.getLowerBound(), loop.getUpperBound(),
574                                   loop.getStep(), loop.getInductionVar());
575 
576   loop.setLowerBound(loopPieces.lowerBound);
577   loop.setUpperBound(loopPieces.upperBound);
578   loop.setStep(loopPieces.step);
579 }
580 
581 void mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
582   if (loops.size() < 2)
583     return;
584 
585   scf::ForOp innermost = loops.back();
586   scf::ForOp outermost = loops.front();
587 
588   // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
589   // allows the following code to assume upperBound is the number of iterations.
590   for (auto loop : loops)
591     normalizeLoop(loop, outermost, innermost);
592 
593   // 2. Emit code computing the upper bound of the coalesced loop as product
594   // of the number of iterations of all loops.
595   OpBuilder builder(outermost);
596   Location loc = outermost.getLoc();
597   Value upperBound = outermost.getUpperBound();
598   for (auto loop : loops.drop_front())
599     upperBound =
600         builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
601   outermost.setUpperBound(upperBound);
602 
603   builder.setInsertionPointToStart(outermost.getBody());
604 
605   // 3. Remap induction variables. For each original loop, the value of the
606   // induction variable can be obtained by dividing the induction variable of
607   // the linearized loop by the total number of iterations of the loops nested
608   // in it modulo the number of iterations in this loop (remove the values
609   // related to the outer loops):
610   //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
611   // Compute these iteratively from the innermost loop by creating a "running
612   // quotient" of division by the range.
613   Value previous = outermost.getInductionVar();
614   for (unsigned i = 0, e = loops.size(); i < e; ++i) {
615     unsigned idx = loops.size() - i - 1;
616     if (i != 0)
617       previous = builder.create<arith::DivSIOp>(loc, previous,
618                                                 loops[idx + 1].getUpperBound());
619 
620     Value iv = (i == e - 1) ? previous
621                             : builder.create<arith::RemSIOp>(
622                                   loc, previous, loops[idx].getUpperBound());
623     replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
624                                loops.back().getRegion());
625   }
626 
627   // 4. Move the operations from the innermost just above the second-outermost
628   // loop, delete the extra terminator and the second-outermost loop.
629   scf::ForOp second = loops[1];
630   innermost.getBody()->back().erase();
631   outermost.getBody()->getOperations().splice(
632       Block::iterator(second.getOperation()),
633       innermost.getBody()->getOperations());
634   second.erase();
635 }
636 
637 void mlir::collapseParallelLoops(
638     scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
639   OpBuilder outsideBuilder(loops);
640   Location loc = loops.getLoc();
641 
642   // Presort combined dimensions.
643   auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
644   for (auto &dims : sortedDimensions)
645     std::sort(dims.begin(), dims.end());
646 
647   // Normalize ParallelOp's iteration pattern.
648   SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
649       normalizedUpperBounds;
650   for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
651     OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
652     auto resultBounds =
653         normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
654                       loops.getLowerBound()[i], loops.getUpperBound()[i],
655                       loops.getStep()[i], loops.getBody()->getArgument(i));
656 
657     normalizedLowerBounds.push_back(resultBounds.lowerBound);
658     normalizedUpperBounds.push_back(resultBounds.upperBound);
659     normalizedSteps.push_back(resultBounds.step);
660   }
661 
662   // Combine iteration spaces.
663   SmallVector<Value, 3> lowerBounds, upperBounds, steps;
664   auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
665   auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
666   for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
667     Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
668     for (auto idx : sortedDimensions[i]) {
669       newUpperBound = outsideBuilder.create<arith::MulIOp>(
670           loc, newUpperBound, normalizedUpperBounds[idx]);
671     }
672     lowerBounds.push_back(cst0);
673     steps.push_back(cst1);
674     upperBounds.push_back(newUpperBound);
675   }
676 
677   // Create new ParallelLoop with conversions to the original induction values.
678   // The loop below uses divisions to get the relevant range of values in the
679   // new induction value that represent each range of the original induction
680   // value. The remainders then determine based on that range, which iteration
681   // of the original induction value this represents. This is a normalized value
682   // that is un-normalized already by the previous logic.
683   auto newPloop = outsideBuilder.create<scf::ParallelOp>(
684       loc, lowerBounds, upperBounds, steps,
685       [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
686         for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
687           Value previous = ploopIVs[i];
688           unsigned numberCombinedDimensions = combinedDimensions[i].size();
689           // Iterate over all except the last induction value.
690           for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
691             unsigned idx = combinedDimensions[i][j];
692 
693             // Determine the current induction value's current loop iteration
694             Value iv = insideBuilder.create<arith::RemSIOp>(
695                 loc, previous, normalizedUpperBounds[idx]);
696             replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
697                                        loops.getRegion());
698 
699             // Remove the effect of the current induction value to prepare for
700             // the next value.
701             previous = insideBuilder.create<arith::DivSIOp>(
702                 loc, previous, normalizedUpperBounds[idx]);
703           }
704 
705           // The final induction value is just the remaining value.
706           unsigned idx = combinedDimensions[i][0];
707           replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
708                                      previous, loops.getRegion());
709         }
710       });
711 
712   // Replace the old loop with the new loop.
713   loops.getBody()->back().erase();
714   newPloop.getBody()->getOperations().splice(
715       Block::iterator(newPloop.getBody()->back()),
716       loops.getBody()->getOperations());
717   loops.erase();
718 }
719 
720 // Hoist the ops within `outer` that appear before `inner`.
721 // Such ops include the ops that have been introduced by parametric tiling.
722 // Ops that come from triangular loops (i.e. that belong to the program slice
723 // rooted at `outer`) and ops that have side effects cannot be hoisted.
724 // Return failure when any op fails to hoist.
725 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
726   SetVector<Operation *> forwardSlice;
727   getForwardSlice(
728       outer.getInductionVar(), &forwardSlice,
729       [&inner](Operation *op) { return op != inner.getOperation(); });
730   LogicalResult status = success();
731   SmallVector<Operation *, 8> toHoist;
732   for (auto &op : outer.getBody()->without_terminator()) {
733     // Stop when encountering the inner loop.
734     if (&op == inner.getOperation())
735       break;
736     // Skip over non-hoistable ops.
737     if (forwardSlice.count(&op) > 0) {
738       status = failure();
739       continue;
740     }
741     // Skip intermediate scf::ForOp, these are not considered a failure.
742     if (isa<scf::ForOp>(op))
743       continue;
744     // Skip other ops with regions.
745     if (op.getNumRegions() > 0) {
746       status = failure();
747       continue;
748     }
749     // Skip if op has side effects.
750     // TODO: loads to immutable memory regions are ok.
751     if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
752       status = failure();
753       continue;
754     }
755     toHoist.push_back(&op);
756   }
757   auto *outerForOp = outer.getOperation();
758   for (auto *op : toHoist)
759     op->moveBefore(outerForOp);
760   return status;
761 }
762 
763 // Traverse the interTile and intraTile loops and try to hoist ops such that
764 // bands of perfectly nested loops are isolated.
765 // Return failure if either perfect interTile or perfect intraTile bands cannot
766 // be formed.
767 static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
768   LogicalResult status = success();
769   const Loops &interTile = tileLoops.first;
770   const Loops &intraTile = tileLoops.second;
771   auto size = interTile.size();
772   assert(size == intraTile.size());
773   if (size <= 1)
774     return success();
775   for (unsigned s = 1; s < size; ++s)
776     status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
777                                : failure();
778   for (unsigned s = 1; s < size; ++s)
779     status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
780                                : failure();
781   return status;
782 }
783 
784 /// Collect perfectly nested loops starting from `rootForOps`.  Loops are
785 /// perfectly nested if each loop is the first and only non-terminator operation
786 /// in the parent loop.  Collect at most `maxLoops` loops and append them to
787 /// `forOps`.
788 template <typename T>
789 static void getPerfectlyNestedLoopsImpl(
790     SmallVectorImpl<T> &forOps, T rootForOp,
791     unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
792   for (unsigned i = 0; i < maxLoops; ++i) {
793     forOps.push_back(rootForOp);
794     Block &body = rootForOp.getRegion().front();
795     if (body.begin() != std::prev(body.end(), 2))
796       return;
797 
798     rootForOp = dyn_cast<T>(&body.front());
799     if (!rootForOp)
800       return;
801   }
802 }
803 
804 static Loops stripmineSink(scf::ForOp forOp, Value factor,
805                            ArrayRef<scf::ForOp> targets) {
806   auto originalStep = forOp.getStep();
807   auto iv = forOp.getInductionVar();
808 
809   OpBuilder b(forOp);
810   forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
811 
812   Loops innerLoops;
813   for (auto t : targets) {
814     // Save information for splicing ops out of t when done
815     auto begin = t.getBody()->begin();
816     auto nOps = t.getBody()->getOperations().size();
817 
818     // Insert newForOp before the terminator of `t`.
819     auto b = OpBuilder::atBlockTerminator((t.getBody()));
820     Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
821     Value less = b.create<arith::CmpIOp>(t.getLoc(), arith::CmpIPredicate::slt,
822                                          forOp.getUpperBound(), stepped);
823     Value ub = b.create<arith::SelectOp>(t.getLoc(), less,
824                                          forOp.getUpperBound(), stepped);
825 
826     // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
827     auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
828     newForOp.getBody()->getOperations().splice(
829         newForOp.getBody()->getOperations().begin(),
830         t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
831     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
832                                newForOp.getRegion());
833 
834     innerLoops.push_back(newForOp);
835   }
836 
837   return innerLoops;
838 }
839 
840 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
841 // Returns the new for operation, nested immediately under `target`.
842 template <typename SizeType>
843 static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
844                                 scf::ForOp target) {
845   // TODO: Use cheap structural assertions that targets are nested under
846   // forOp and that targets are not nested under each other when DominanceInfo
847   // exposes the capability. It seems overkill to construct a whole function
848   // dominance tree at this point.
849   auto res = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
850   assert(res.size() == 1 && "Expected 1 inner forOp");
851   return res[0];
852 }
853 
854 SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
855                                  ArrayRef<Value> sizes,
856                                  ArrayRef<scf::ForOp> targets) {
857   SmallVector<SmallVector<scf::ForOp, 8>, 8> res;
858   SmallVector<scf::ForOp, 8> currentTargets(targets.begin(), targets.end());
859   for (auto it : llvm::zip(forOps, sizes)) {
860     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
861     res.push_back(step);
862     currentTargets = step;
863   }
864   return res;
865 }
866 
867 Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
868                  scf::ForOp target) {
869   SmallVector<scf::ForOp, 8> res;
870   for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
871     assert(loops.size() == 1);
872     res.push_back(loops[0]);
873   }
874   return res;
875 }
876 
877 Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
878   // Collect perfectly nested loops.  If more size values provided than nested
879   // loops available, truncate `sizes`.
880   SmallVector<scf::ForOp, 4> forOps;
881   forOps.reserve(sizes.size());
882   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
883   if (forOps.size() < sizes.size())
884     sizes = sizes.take_front(forOps.size());
885 
886   return ::tile(forOps, sizes, forOps.back());
887 }
888 
889 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
890                                    scf::ForOp root) {
891   getPerfectlyNestedLoopsImpl(nestedLoops, root);
892 }
893 
894 TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
895                                        ArrayRef<int64_t> sizes) {
896   // Collect perfectly nested loops.  If more size values provided than nested
897   // loops available, truncate `sizes`.
898   SmallVector<scf::ForOp, 4> forOps;
899   forOps.reserve(sizes.size());
900   getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
901   if (forOps.size() < sizes.size())
902     sizes = sizes.take_front(forOps.size());
903 
904   // Compute the tile sizes such that i-th outer loop executes size[i]
905   // iterations.  Given that the loop current executes
906   //   numIterations = ceildiv((upperBound - lowerBound), step)
907   // iterations, we need to tile with size ceildiv(numIterations, size[i]).
908   SmallVector<Value, 4> tileSizes;
909   tileSizes.reserve(sizes.size());
910   for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
911     assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
912 
913     auto forOp = forOps[i];
914     OpBuilder builder(forOp);
915     auto loc = forOp.getLoc();
916     Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(),
917                                                forOp.getLowerBound());
918     Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
919     Value iterationsPerBlock =
920         ceilDivPositive(builder, loc, numIterations, sizes[i]);
921     tileSizes.push_back(iterationsPerBlock);
922   }
923 
924   // Call parametric tiling with the given sizes.
925   auto intraTile = tile(forOps, tileSizes, forOps.back());
926   TileLoops tileLoops = std::make_pair(forOps, intraTile);
927 
928   // TODO: for now we just ignore the result of band isolation.
929   // In the future, mapping decisions may be impacted by the ability to
930   // isolate perfectly nested bands.
931   (void)tryIsolateBands(tileLoops);
932 
933   return tileLoops;
934 }
935