1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/Utils.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
20 #include "mlir/Dialect/Affine/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/IR/BlockAndValueMapping.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/Support/MathExtras.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/RegionUtils.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/SmallPtrSet.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #define DEBUG_TYPE "LoopUtils"
34 
35 using namespace mlir;
36 using llvm::SmallMapVector;
37 
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
//
// Groups the SSA values for a loop's lower bound, upper bound and step under
// named fields. NOTE(review): the field order is presumably relied upon by
// positional (aggregate) initialization at the users later in this file;
// do not reorder without checking them.
struct LoopParams {
  Value lowerBound;
  Value upperBound;
  Value step;
};
} // namespace
47 
48 /// Computes the cleanup loop lower bound of the loop being unrolled with
49 /// the specified unroll factor; this bound will also be upper bound of the main
50 /// part of the unrolled loop. Computes the bound as an AffineMap with its
51 /// operands or a null map when the trip count can't be expressed as an affine
52 /// expression.
53 static void
54 getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
55                          AffineMap &cleanupLbMap,
56                          SmallVectorImpl<Value> &cleanupLbOperands) {
57   AffineMap tripCountMap;
58   SmallVector<Value, 4> tripCountOperands;
59   getTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
60   // Trip count can't be computed.
61   if (!tripCountMap) {
62     cleanupLbMap = AffineMap();
63     return;
64   }
65 
66   OpBuilder b(forOp);
67   auto lbMap = forOp.getLowerBoundMap();
68   auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
69                                     forOp.getLowerBoundOperands());
70 
71   // For each upper bound expr, get the range.
72   // Eg: affine.for %i = lb to min (ub1, ub2),
73   // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
74   // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
75   // these affine.apply's make up the cleanup loop lower bound.
76   SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
77   SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
78   int64_t step = forOp.getStep();
79   for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
80     auto tripCountExpr = tripCountMap.getResult(i);
81     bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
82     auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
83                                   tripCountMap.getNumSymbols(), bumpExprs[i]);
84     bumpValues[i] =
85         b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
86   }
87 
88   SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
89   for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
90     newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
91 
92   cleanupLbOperands.clear();
93   cleanupLbOperands.push_back(lb);
94   cleanupLbOperands.append(bumpValues.begin(), bumpValues.end());
95   cleanupLbMap = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs,
96                                 b.getContext());
97   // Simplify the cleanupLbMap + cleanupLbOperands.
98   fullyComposeAffineMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
99   cleanupLbMap = simplifyAffineMap(cleanupLbMap);
100   canonicalizeMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
101   // Remove any affine.apply's that became dead from the simplification above.
102   for (auto v : bumpValues)
103     if (v.use_empty())
104       v.getDefiningOp()->erase();
105 
106   if (lb.use_empty())
107     lb.erase();
108 }
109 
110 /// Helper to replace uses of loop carried values (iter_args) and loop
111 /// yield values while promoting single iteration affine.for ops.
112 static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
113   // Replace uses of iter arguments with iter operands (initial values).
114   auto iterOperands = forOp.getIterOperands();
115   auto iterArgs = forOp.getRegionIterArgs();
116   for (auto e : llvm::zip(iterOperands, iterArgs))
117     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
118 
119   // Replace uses of loop results with the values yielded by the loop.
120   auto outerResults = forOp.getResults();
121   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
122   for (auto e : llvm::zip(outerResults, innerResults))
123     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
124 }
125 
126 /// Promotes the loop body of a forOp to its containing block if the forOp
127 /// was known to have a single iteration.
128 // TODO: extend this for arbitrary affine bounds.
129 LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
130   Optional<uint64_t> tripCount = getConstantTripCount(forOp);
131   if (!tripCount || tripCount.getValue() != 1)
132     return failure();
133 
134   if (forOp.getLowerBoundMap().getNumResults() != 1)
135     return failure();
136 
137   // Replaces all IV uses to its single iteration value.
138   auto iv = forOp.getInductionVar();
139   auto *parentBlock = forOp->getBlock();
140   if (!iv.use_empty()) {
141     if (forOp.hasConstantLowerBound()) {
142       OpBuilder topBuilder(forOp->getParentOfType<FuncOp>().getBody());
143       auto constOp = topBuilder.create<arith::ConstantIndexOp>(
144           forOp.getLoc(), forOp.getConstantLowerBound());
145       iv.replaceAllUsesWith(constOp);
146     } else {
147       auto lbOperands = forOp.getLowerBoundOperands();
148       auto lbMap = forOp.getLowerBoundMap();
149       OpBuilder builder(forOp);
150       if (lbMap == builder.getDimIdentityMap()) {
151         // No need of generating an affine.apply.
152         iv.replaceAllUsesWith(lbOperands[0]);
153       } else {
154         auto affineApplyOp =
155             builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
156         iv.replaceAllUsesWith(affineApplyOp);
157       }
158     }
159   }
160 
161   replaceIterArgsAndYieldResults(forOp);
162 
163   // Move the loop body operations, except for its terminator, to the loop's
164   // containing block.
165   forOp.getBody()->back().erase();
166   parentBlock->getOperations().splice(Block::iterator(forOp),
167                                       forOp.getBody()->getOperations());
168   forOp.erase();
169   return success();
170 }
171 
172 /// Generates an affine.for op with the specified lower and upper bounds
173 /// while generating the right IV remappings to realize shifts for operations in
174 /// its body. The operations that go into the loop body are specified in
175 /// opGroupQueue starting from the specified offset, and in that order. The
176 /// first element of the pair specifies the shift applied to that group of
177 /// operations; the shift is multiplied by the loop step before being applied.
178 /// Returns nullptr if the generated loop simplifies to a single iteration one.
179 static AffineForOp generateShiftedLoop(
180     AffineMap lbMap, AffineMap ubMap,
181     const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> &opGroupQueue,
182     unsigned offset, AffineForOp srcForOp, OpBuilder b) {
183   auto lbOperands = srcForOp.getLowerBoundOperands();
184   auto ubOperands = srcForOp.getUpperBoundOperands();
185 
186   assert(lbMap.getNumInputs() == lbOperands.size());
187   assert(ubMap.getNumInputs() == ubOperands.size());
188 
189   auto loopChunk = b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap,
190                                          ubOperands, ubMap, srcForOp.getStep());
191   auto loopChunkIV = loopChunk.getInductionVar();
192   auto srcIV = srcForOp.getInductionVar();
193 
194   BlockAndValueMapping operandMap;
195 
196   auto bodyBuilder = OpBuilder::atBlockTerminator(loopChunk.getBody());
197   for (auto it = opGroupQueue.begin() + offset, e = opGroupQueue.end(); it != e;
198        ++it) {
199     uint64_t shift = it->first;
200     auto ops = it->second;
201     // All 'same shift' operations get added with their operands being
202     // remapped to results of cloned operations, and their IV used remapped.
203     // Generate the remapping if the shift is not zero: remappedIV = newIV -
204     // shift.
205     if (!srcIV.use_empty() && shift != 0) {
206       auto ivRemap = bodyBuilder.create<AffineApplyOp>(
207           srcForOp.getLoc(),
208           bodyBuilder.getSingleDimShiftAffineMap(
209               -static_cast<int64_t>(srcForOp.getStep() * shift)),
210           loopChunkIV);
211       operandMap.map(srcIV, ivRemap);
212     } else {
213       operandMap.map(srcIV, loopChunkIV);
214     }
215     for (auto *op : ops)
216       bodyBuilder.clone(*op, operandMap);
217   };
218   if (succeeded(promoteIfSingleIteration(loopChunk)))
219     return AffineForOp();
220   return loopChunk;
221 }
222 
223 // The skewing of operations with respect to one another can be used for
224 // example to allow overlap of asynchronous operations (such as DMA
225 // communication) with computation, or just relative shifting of operations
226 // for better register reuse, locality or parallelism. As such, the shifts are
227 // typically expected to be at most of the order of the number of operations.
228 // This method should not be used as a substitute for loop distribution/fission.
229 // This method uses an algorithm// in time linear in the number of operations
230 // in the body of the for loop - (using the 'sweep line' paradigm). This method
231 // asserts preservation of SSA dominance. A check for that as well as that for
232 // memory-based dependence preservation check rests with the users of this
233 // method.
234 LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
235                                         ArrayRef<uint64_t> shifts,
236                                         bool unrollPrologueEpilogue) {
237   assert(forOp.getBody()->getOperations().size() == shifts.size() &&
238          "too few/many shifts");
239   if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
240     return success();
241 
242   // If the trip counts aren't constant, we would need versioning and
243   // conditional guards (or context information to prevent such versioning). The
244   // better way to pipeline for such loops is to first tile them and extract
245   // constant trip count "full tiles" before applying this.
246   auto mayBeConstTripCount = getConstantTripCount(forOp);
247   if (!mayBeConstTripCount.hasValue()) {
248     LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
249     return success();
250   }
251   uint64_t tripCount = mayBeConstTripCount.getValue();
252 
253   assert(isOpwiseShiftValid(forOp, shifts) &&
254          "shifts will lead to an invalid transformation\n");
255 
256   int64_t step = forOp.getStep();
257 
258   unsigned numChildOps = shifts.size();
259 
260   // Do a linear time (counting) sort for the shifts.
261   uint64_t maxShift = *std::max_element(shifts.begin(), shifts.end());
262   if (maxShift >= numChildOps) {
263     // Large shifts are not the typical use case.
264     forOp.emitWarning("not shifting because shifts are unrealistically large");
265     return success();
266   }
267 
268   // An array of operation groups sorted by shift amount; each group has all
269   // operations with the same shift in the order in which they appear in the
270   // body of the 'affine.for' op.
271   std::vector<std::vector<Operation *>> sortedOpGroups(maxShift + 1);
272   unsigned pos = 0;
273   for (auto &op : forOp.getBody()->without_terminator()) {
274     auto shift = shifts[pos++];
275     sortedOpGroups[shift].push_back(&op);
276   }
277 
278   // Unless the shifts have a specific pattern (which actually would be the
279   // common use case), prologue and epilogue are not meaningfully defined.
280   // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
281   // loop generated as the prologue and the last as epilogue and unroll these
282   // fully.
283   AffineForOp prologue, epilogue;
284 
285   // Do a sweep over the sorted shifts while storing open groups in a
286   // vector, and generating loop portions as necessary during the sweep. A block
287   // of operations is paired with its shift.
288   std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> opGroupQueue;
289 
290   auto origLbMap = forOp.getLowerBoundMap();
291   uint64_t lbShift = 0;
292   OpBuilder b(forOp);
293   for (uint64_t d = 0, e = sortedOpGroups.size(); d < e; ++d) {
294     // If nothing is shifted by d, continue.
295     if (sortedOpGroups[d].empty())
296       continue;
297     if (!opGroupQueue.empty()) {
298       assert(d > 0 &&
299              "Queue expected to be empty when the first block is found");
300       // The interval for which the loop needs to be generated here is:
301       // [lbShift, min(lbShift + tripCount, d)) and the body of the
302       // loop needs to have all operations in opQueue in that order.
303       AffineForOp res;
304       if (lbShift + tripCount * step < d * step) {
305         res = generateShiftedLoop(
306             b.getShiftedAffineMap(origLbMap, lbShift),
307             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
308             opGroupQueue, /*offset=*/0, forOp, b);
309         // Entire loop for the queued op groups generated, empty it.
310         opGroupQueue.clear();
311         lbShift += tripCount * step;
312       } else {
313         res = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
314                                   b.getShiftedAffineMap(origLbMap, d),
315                                   opGroupQueue, /*offset=*/0, forOp, b);
316         lbShift = d * step;
317       }
318 
319       if (res) {
320         // Simplify/canonicalize the affine.for.
321         RewritePatternSet patterns(res.getContext());
322         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
323         bool erased;
324         (void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
325 
326         if (!erased && !prologue)
327           prologue = res;
328         if (!erased)
329           epilogue = res;
330       }
331     } else {
332       // Start of first interval.
333       lbShift = d * step;
334     }
335     // Augment the list of operations that get into the current open interval.
336     opGroupQueue.emplace_back(d, sortedOpGroups[d]);
337   }
338 
339   // Those operations groups left in the queue now need to be processed (FIFO)
340   // and their loops completed.
341   for (unsigned i = 0, e = opGroupQueue.size(); i < e; ++i) {
342     uint64_t ubShift = (opGroupQueue[i].first + tripCount) * step;
343     epilogue = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
344                                    b.getShiftedAffineMap(origLbMap, ubShift),
345                                    opGroupQueue, /*offset=*/i, forOp, b);
346     lbShift = ubShift;
347     if (!prologue)
348       prologue = epilogue;
349   }
350 
351   // Erase the original for op.
352   forOp.erase();
353 
354   if (unrollPrologueEpilogue && prologue)
355     (void)loopUnrollFull(prologue);
356   if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
357     (void)loopUnrollFull(epilogue);
358 
359   return success();
360 }
361 
362 /// Checks the legality of tiling of a hyper-rectangular loop nest by simply
363 /// checking if there is a 'negative' dependence in the memrefs present in
364 /// the loop nest. If yes then tiling is invalid.
365 static bool
366 checkTilingLegalityImpl(MutableArrayRef<mlir::AffineForOp> origLoops) {
367   assert(!origLoops.empty() && "no original loops provided");
368 
369   // We first find out all dependences we intend to check.
370   SmallVector<Operation *, 8> loadAndStoreOps;
371   origLoops[0]->walk([&](Operation *op) {
372     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
373       loadAndStoreOps.push_back(op);
374   });
375 
376   unsigned numOps = loadAndStoreOps.size();
377   unsigned numLoops = origLoops.size();
378   FlatAffineValueConstraints dependenceConstraints;
379   for (unsigned d = 1; d <= numLoops + 1; ++d) {
380     for (unsigned i = 0; i < numOps; ++i) {
381       Operation *srcOp = loadAndStoreOps[i];
382       MemRefAccess srcAccess(srcOp);
383       for (unsigned j = 0; j < numOps; ++j) {
384         Operation *dstOp = loadAndStoreOps[j];
385         MemRefAccess dstAccess(dstOp);
386 
387         SmallVector<DependenceComponent, 2> depComps;
388         dependenceConstraints.reset();
389         DependenceResult result = checkMemrefAccessDependence(
390             srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
391 
392         // Skip if there is no dependence in this case.
393         if (!hasDependence(result))
394           continue;
395 
396         // Check whether there is any negative direction vector in the
397         // dependence components found above, which means that dependence is
398         // violated by the default hyper-rect tiling method.
399         LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
400                                    "for dependence at depth: "
401                                 << Twine(d) << " between:\n";);
402         LLVM_DEBUG(srcAccess.opInst->dump(););
403         LLVM_DEBUG(dstAccess.opInst->dump(););
404         for (unsigned k = 0, e = depComps.size(); k < e; k++) {
405           DependenceComponent depComp = depComps[k];
406           if (depComp.lb.hasValue() && depComp.ub.hasValue() &&
407               depComp.lb.getValue() < depComp.ub.getValue() &&
408               depComp.ub.getValue() < 0) {
409             LLVM_DEBUG(llvm::dbgs()
410                        << "Dependence component lb = "
411                        << Twine(depComp.lb.getValue())
412                        << " ub = " << Twine(depComp.ub.getValue())
413                        << " is negative  at depth: " << Twine(d)
414                        << " and thus violates the legality rule.\n");
415             return false;
416           }
417         }
418       }
419     }
420   }
421 
422   return true;
423 }
424 
425 /// Checks whether hyper-rectangular loop tiling of the nest
426 /// represented by `origLoops` is valid. The validity condition is from Irigoin
427 /// and Triolet, which states that two tiles cannot depend on each other. We
428 /// simplify such condition to just checking whether there is any negative
429 /// dependence direction, since we have the prior knowledge that the tiling
430 /// results will be hyper-rectangles, which are scheduled in the
431 /// lexicographically increasing order on the vector of loop indices. This
432 /// function will return failure when any dependence component is negative along
433 /// any of `origLoops`.
434 LogicalResult
435 checkTilingLegality(MutableArrayRef<mlir::AffineForOp> origLoops) {
436   return success(checkTilingLegalityImpl(origLoops));
437 }
438 
439 /// Checks whether a loop nest is hyper-rectangular or not.
440 LogicalResult checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
441   FlatAffineValueConstraints cst;
442   SmallVector<Operation *, 8> ops(input.begin(), input.end());
443   // 0-d or 1-d is trivially hyper-rectangular.
444   if (input.size() <= 1)
445     return success();
446   if (failed(getIndexSet(ops, &cst))) {
447     LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
448     return failure();
449   }
450   if (!cst.isHyperRectangular(0, input.size())) {
451     LLVM_DEBUG(llvm::dbgs()
452                << "Non-hyperrectangular nests not supported for tiling!\n");
453     return failure();
454   }
455   return success();
456 }
457 
458 /// Check if the input nest is supported for tiling and whether tiling would be
459 /// legal or not.
460 template <typename t>
461 LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
462                                      ArrayRef<t> tileSizes) {
463   assert(input.size() == tileSizes.size() && "Too few/many tile sizes");
464 
465   if (llvm::any_of(input,
466                    [](AffineForOp op) { return op.getNumResults() > 0; })) {
467     LLVM_DEBUG(llvm::dbgs()
468                << "Cannot tile nest where a loop has yield values\n");
469     return failure();
470   }
471 
472   // Check if the supplied `for` ops are all successively nested.
473   if (!isPerfectlyNested(input)) {
474     LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
475     return failure();
476   }
477 
478   if (failed(checkIfHyperRectangular(input)))
479     return failure();
480 
481   // Check if tiling is legal.
482   if (failed(checkTilingLegality(input))) {
483     input[0].emitRemark("tiling code is illegal due to dependences");
484     return failure();
485   }
486 
487   return success();
488 }
489 
490 /// Move the loop body of AffineForOp 'src' from 'src' into the specified
491 /// location in destination's body, ignoring the terminator.
492 static void moveLoopBodyImpl(AffineForOp src, AffineForOp dest,
493                              Block::iterator loc) {
494   auto &ops = src.getBody()->getOperations();
495   dest.getBody()->getOperations().splice(loc, ops, ops.begin(),
496                                          std::prev(ops.end()));
497 }
498 
499 /// Move the loop body of AffineForOp 'src' from 'src' to the start of dest
500 /// body.
501 void moveLoopBody(AffineForOp src, AffineForOp dest) {
502   moveLoopBodyImpl(src, dest, dest.getBody()->begin());
503 }
504 
505 /// Constructs tiled loop nest, without setting the loop bounds and move the
506 /// body of the original loop nest to the tiled loop nest.
507 void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
508                             AffineForOp rootAffineForOp, unsigned width,
509                             MutableArrayRef<AffineForOp> tiledLoops) {
510   Location loc = rootAffineForOp.getLoc();
511 
512   // The outermost among the loops as we add more..
513   Operation *topLoop = rootAffineForOp.getOperation();
514   AffineForOp innermostPointLoop;
515 
516   // Add intra-tile (or point) loops.
517   for (unsigned i = 0; i < width; i++) {
518     OpBuilder b(topLoop);
519     // Loop bounds will be set later.
520     AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
521     pointLoop.getBody()->getOperations().splice(
522         pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
523         topLoop);
524     tiledLoops[2 * width - 1 - i] = pointLoop;
525     topLoop = pointLoop.getOperation();
526     if (i == 0)
527       innermostPointLoop = pointLoop;
528   }
529 
530   // Add tile space loops;
531   for (unsigned i = width; i < 2 * width; i++) {
532     OpBuilder b(topLoop);
533     // Loop bounds will be set later.
534     AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
535     tileSpaceLoop.getBody()->getOperations().splice(
536         tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
537         topLoop);
538     tiledLoops[2 * width - i - 1] = tileSpaceLoop;
539     topLoop = tileSpaceLoop.getOperation();
540   }
541 
542   // Move the loop body of the original nest to the new one.
543   moveLoopBody(origLoops.back(), innermostPointLoop);
544 }
545 
546 /// Set lower and upper bounds of intra-tile loops for parametric tiling.
547 //  TODO: Handle non-constant lower bounds.
548 static void setIntraTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
549                                          AffineForOp newInterTileLoop,
550                                          AffineForOp newIntraTileLoop,
551                                          Value tileSize) {
552   // The lower bound for the intra-tile loop is represented by an affine map
553   // as (%i, %t0)->((%i - %origlb) * %t0 + %origlb). Similarly, the upper bound
554   // for the intra-tile loop is represented by an affine map as (%i, %t0)->((%i
555   // - %origlb) * %t0) + (%t0 * %origLoopStep) + %origlb), where %i is loop IV
556   // of the corresponding inter-tile loop, %t0 is the corresponding tiling
557   // parameter, %origlb is lower bound and %origLoopStep is the loop step of the
558   // corresponding inter-tile loop.
559 
560   assert(origLoop.hasConstantLowerBound() &&
561          "expected input loops to have constant lower bound.");
562 
563   // Get lower bound of original loop as an affine expression.
564   AffineExpr origLowerBoundExpr;
565   origLowerBoundExpr =
566       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
567 
568   // Add dim operands from original lower/upper bound.
569   SmallVector<Value, 4> lbOperands, ubOperands;
570   AffineBound lb = origLoop.getLowerBound();
571   AffineBound ub = origLoop.getUpperBound();
572   lbOperands.reserve(lb.getNumOperands() + 2);
573   ubOperands.reserve(ub.getNumOperands() + 2);
574   AffineMap origLbMap = lb.getMap();
575   AffineMap origUbMap = ub.getMap();
576   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
577     lbOperands.push_back(lb.getOperand(j));
578   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
579     ubOperands.push_back(ub.getOperand(j));
580 
581   // Add a new dim operand in lb/ubOperands corresponding to the origLoop
582   // IV.
583   lbOperands.push_back(newInterTileLoop.getInductionVar());
584   ubOperands.push_back(newInterTileLoop.getInductionVar());
585 
586   // Get loop IV as an affine expression for lower/upper bound. Size of
587   // lb/ubOperands is guaranteed to be atleast one.
588   AffineExpr lbLoopIvExpr = b.getAffineDimExpr(lbOperands.size() - 1);
589   AffineExpr ubLoopIvExpr = b.getAffineDimExpr(ubOperands.size() - 1);
590 
591   // Add symbol operands from original lower/upper bound.
592   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
593     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
594   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
595     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
596 
597   // Add a new symbol operand which is the tile size for this loop.
598   lbOperands.push_back(tileSize);
599   ubOperands.push_back(tileSize);
600 
601   SmallVector<AffineExpr, 4> lbBoundExprs;
602   SmallVector<AffineExpr, 4> ubBoundExprs;
603   lbBoundExprs.reserve(origLbMap.getNumResults());
604   ubBoundExprs.reserve(origUbMap.getNumResults());
605 
606   // Get tiling parameter as an affine expression for lb/ub.
607   AffineExpr lbTileParameter = b.getAffineSymbolExpr(origLbMap.getNumSymbols());
608   AffineExpr ubTileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
609 
610   // Insert lb as inter-tile ((loop IV - origlb) * tilingParameter) + origlb.
611   lbBoundExprs.push_back(
612       ((lbLoopIvExpr - origLowerBoundExpr) * lbTileParameter) +
613       origLowerBoundExpr);
614 
615   // Get the origLoopStep as an affine expression.
616   AffineExpr origLoopStep = b.getAffineConstantExpr(origLoop.getStep());
617 
618   // Insert ub as inter-tile ((loop IV - origlb) * tilingParameter) +
619   // (tilingParameter * origLoopStep) + origlb.
620   ubBoundExprs.push_back(
621       ((ubLoopIvExpr - origLowerBoundExpr) * ubTileParameter) +
622       (ubTileParameter * origLoopStep) + origLowerBoundExpr);
623 
624   ubBoundExprs.append(origUbMap.getResults().begin(),
625                       origUbMap.getResults().end());
626 
627   AffineMap lbMap =
628       AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols() + 1,
629                      lbBoundExprs, b.getContext());
630   newIntraTileLoop.setLowerBound(lbOperands, lbMap);
631 
632   AffineMap ubMap =
633       AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols() + 1,
634                      ubBoundExprs, b.getContext());
635   newIntraTileLoop.setUpperBound(ubOperands, ubMap);
636 
637   // Original loop step must be preserved.
638   newIntraTileLoop.setStep(origLoop.getStep());
639 }
640 
641 /// Set lower and upper bounds of inter-tile loops for parametric tiling.
642 //  TODO: Handle non-constant lower bounds.
643 static void setInterTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
644                                          AffineForOp newLoop, Value tileSize) {
645   OperandRange newLbOperands = origLoop.getLowerBoundOperands();
646 
647   // The lower bounds for inter-tile loops are same as the corresponding lower
648   // bounds of original loops.
649   newLoop.setLowerBound(newLbOperands, origLoop.getLowerBoundMap());
650 
651   // The new upper bound map for inter-tile loops, assuming constant lower
652   // bounds, are now originalLowerBound + ceildiv((originalUpperBound -
653   // originalLowerBound), tiling parameter); where tiling parameter is the
654   // respective tile size for that loop. For e.g. if the original ubmap was
655   // ()->(1024), the new map will be
656   // ()[s0]->(ceildiv((1024 -lb) % s0)), where s0 is the tiling parameter.
657   // Therefore a new symbol operand is inserted in the map and the result
658   // expression is overwritten.
659 
660   assert(origLoop.hasConstantLowerBound() &&
661          "expected input loops to have constant lower bound.");
662 
663   // Get lower bound of original loop as an affine expression.
664   AffineExpr origLowerBoundExpr;
665   origLowerBoundExpr =
666       b.getAffineConstantExpr(origLoop.getConstantLowerBound());
667 
668   // Add dim operands from original upper bound.
669   SmallVector<Value, 4> ubOperands;
670   AffineBound ub = origLoop.getUpperBound();
671   ubOperands.reserve(ub.getNumOperands() + 1);
672   AffineMap origUbMap = ub.getMap();
673   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
674     ubOperands.push_back(ub.getOperand(j));
675 
676   // Add symbol operands from original upper bound.
677   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
678     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
679 
680   // Add a new symbol operand which is the tile size for this loop.
681   ubOperands.push_back(tileSize);
682 
683   // Get tiling parameter as an affine expression.
684   AffineExpr tileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
685 
686   SmallVector<AffineExpr, 4> boundExprs;
687   boundExprs.reserve(origUbMap.getNumResults());
688   int64_t origUpperBound;
689   AffineExpr origUpperBoundExpr;
690 
691   // If upper bound for the original loop is constant, then the constant can
692   // be obtained as an affine expression straight away.
693   if (origLoop.hasConstantUpperBound()) {
694     origUpperBound = origLoop.getConstantUpperBound();
695 
696     // Get original constant upper bound as an affine expression.
697     origUpperBoundExpr = b.getAffineConstantExpr(origUpperBound);
698 
699     // Insert the bound as originalLowerBoundceildiv((originalUpperBound -
700     // originalLowerBound), tilingParameter).
701     boundExprs.push_back(
702         origLowerBoundExpr +
703         (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
704   } else {
705     // If upper bound for the original loop is not constant then two cases
706     // are possible, although there handeling is the same, 1.) The result of
707     // ubmap has only one result expression. For e.g.
708     //    affine.for %i = 5 to %ub
709     //
710     // A symbol operand is added which represents the tiling parameter. The
711     // new loop bounds here will be like ()[s0, s1] -> ((s0 - 5) ceildiv s1 + 5)
712     // where 's0' is the original upper bound and 's1' is the tiling
713     // parameter. 2.) When ubMap has more than one result expression. For e.g.
714     //    #map0 = affine_map<()[s0, s1] -> (s0, s1)
715     //    affine.for %i = 5 to min #map0()[%s0, %s1]
716     //
717     // A symbol operand is added which represents the tiling parameter. The
718     // new loop bounds will be like ()[s0, s1, s2] -> ((s0 - 5) ceildiv s2 + 5,
719     // (s1 -5) ceildiv s2 + 5), where s2 is the tiling parameter.
720 
721     // Insert the bounds as originalLowerBound + ceildiv((originalUpperBound -
722     // originalLowerBound), tilingParameter).
723     for (AffineExpr origUpperBoundExpr : origUbMap.getResults())
724       boundExprs.push_back(
725           origLowerBoundExpr +
726           (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
727   }
728 
729   AffineMap ubMap =
730       AffineMap::get(origUbMap.getNumDims(), origUbMap.getNumSymbols() + 1,
731                      boundExprs, b.getContext());
732   newLoop.setUpperBound(ubOperands, ubMap);
733 
734   // Original loop step must be preserved.
735   newLoop.setStep(origLoop.getStep());
736 }
737 
738 /// Constructs and sets new loop bounds after tiling for the case of
739 /// hyper-rectangular index sets, where the bounds of one dimension do not
740 /// depend on other dimensions and tiling parameters are captured from SSA
741 /// values. Bounds of each dimension can thus be treated independently,
742 /// and deriving the new bounds is much simpler and faster than for the case of
743 /// tiling arbitrary polyhedral shapes.
744 static void constructParametricallyTiledIndexSetHyperRect(
745     MutableArrayRef<AffineForOp> origLoops,
746     MutableArrayRef<AffineForOp> newLoops, ArrayRef<Value> tileSizes) {
747   assert(!origLoops.empty() && "expected atleast one loop in band");
748   assert(origLoops.size() == tileSizes.size() &&
749          "expected tiling parameter for each loop in band.");
750 
751   OpBuilder b(origLoops[0].getOperation());
752   unsigned width = origLoops.size();
753 
754   // Set bounds for tile space loops.
755   for (unsigned i = 0; i < width; ++i) {
756     setInterTileBoundsParametric(b, origLoops[i], newLoops[i], tileSizes[i]);
757   }
758 
759   // Set bounds for intra-tile loops.
760   for (unsigned i = 0; i < width; ++i) {
761     setIntraTileBoundsParametric(b, origLoops[i], newLoops[i],
762                                  newLoops[i + width], tileSizes[i]);
763   }
764 }
765 
766 /// Constructs and sets new loop bounds after tiling for the case of
767 /// hyper-rectangular index sets, where the bounds of one dimension do not
768 /// depend on other dimensions. Bounds of each dimension can thus be treated
769 /// independently, and deriving the new bounds is much simpler and faster
770 /// than for the case of tiling arbitrary polyhedral shapes.
771 static void
772 constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
773                                 MutableArrayRef<AffineForOp> newLoops,
774                                 ArrayRef<unsigned> tileSizes) {
775   assert(!origLoops.empty());
776   assert(origLoops.size() == tileSizes.size());
777 
778   OpBuilder b(origLoops[0].getOperation());
779   unsigned width = origLoops.size();
780 
781   // Bounds for tile space loops.
782   for (unsigned i = 0; i < width; i++) {
783     OperandRange newLbOperands = origLoops[i].getLowerBoundOperands();
784     OperandRange newUbOperands = origLoops[i].getUpperBoundOperands();
785     newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
786     newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
787     // If the step size of original loop is x and tileSize is y then after
788     // tiling the tile space loops' step size becomes x*y.
789     newLoops[i].setStep(tileSizes[i] * origLoops[i].getStep());
790   }
791   // Bounds for intra-tile loops.
792   for (unsigned i = 0; i < width; i++) {
793     int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
794     Optional<uint64_t> mayBeConstantCount = getConstantTripCount(origLoops[i]);
795     // The lower bound is just the tile-space loop.
796     AffineMap lbMap = b.getDimIdentityMap();
797     newLoops[width + i].setLowerBound(
798         /*operands=*/newLoops[i].getInductionVar(), lbMap);
799     // The step sizes of intra-tile loops is just the original loops' step size.
800     newLoops[width + i].setStep(origLoops[i].getStep());
801 
802     // Set the upper bound.
803     if (mayBeConstantCount && mayBeConstantCount.getValue() < tileSizes[i]) {
804       // Trip count is less than the tile size: upper bound is lower bound +
805       // trip count * stepSize.
806       AffineMap ubMap = b.getSingleDimShiftAffineMap(
807           mayBeConstantCount.getValue() * origLoops[i].getStep());
808       newLoops[width + i].setUpperBound(
809           /*operands=*/newLoops[i].getInductionVar(), ubMap);
810     } else if (largestDiv % tileSizes[i] != 0) {
811       // Intra-tile loop ii goes from i to min(i + tileSize * stepSize, ub_i).
812       // Construct the upper bound map; the operands are the original operands
813       // with 'i' (tile-space loop) appended to it. The new upper bound map is
814       // the original one with an additional expression i + tileSize * stepSize
815       // appended.
816 
817       // Add dim operands from original upper bound.
818       SmallVector<Value, 4> ubOperands;
819       AffineBound ub = origLoops[i].getUpperBound();
820       ubOperands.reserve(ub.getNumOperands() + 1);
821       AffineMap origUbMap = ub.getMap();
822       for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
823         ubOperands.push_back(ub.getOperand(j));
824 
825       // Add dim operand for new loop upper bound.
826       ubOperands.push_back(newLoops[i].getInductionVar());
827 
828       // Add symbol operands from original upper bound.
829       for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
830         ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
831 
832       SmallVector<AffineExpr, 4> boundExprs;
833       boundExprs.reserve(1 + origUbMap.getNumResults());
834       AffineExpr dim = b.getAffineDimExpr(origUbMap.getNumDims());
835       // The new upper bound map is the original one with an additional
836       // expression i + tileSize * stepSize (of original loop) appended.
837       boundExprs.push_back(dim + tileSizes[i] * origLoops[i].getStep());
838       boundExprs.append(origUbMap.getResults().begin(),
839                         origUbMap.getResults().end());
840       AffineMap ubMap =
841           AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
842                          boundExprs, b.getContext());
843       newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
844     } else {
845       // No need of the min expression.
846       AffineExpr dim = b.getAffineDimExpr(0);
847       AffineMap ubMap =
848           AffineMap::get(1, 0, dim + tileSizes[i] * origLoops[i].getStep());
849       newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
850     }
851   }
852 }
853 
854 /// Tiles the specified band of perfectly nested loops creating tile-space loops
855 /// and intra-tile loops. A band is a contiguous set of loops.
856 //  TODO: handle non hyper-rectangular spaces.
857 LogicalResult
858 mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
859                           ArrayRef<unsigned> tileSizes,
860                           SmallVectorImpl<AffineForOp> *tiledNest) {
861   if (input.empty())
862     return success();
863 
864   if (failed(performPreTilingChecks(input, tileSizes)))
865     return failure();
866 
867   MutableArrayRef<AffineForOp> origLoops = input;
868   AffineForOp rootAffineForOp = origLoops[0];
869 
870   // Note that width is at least one since the band isn't empty.
871   unsigned width = input.size();
872   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
873 
874   // Construct a tiled loop nest without setting their bounds. Bounds are
875   // set later.
876   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
877 
878   SmallVector<Value, 8> origLoopIVs;
879   extractForInductionVars(input, &origLoopIVs);
880 
881   // Set loop bounds for the tiled loop nest.
882   constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
883 
884   // Replace original IVs with intra-tile loop IVs.
885   for (unsigned i = 0; i < width; i++)
886     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
887 
888   // Erase the old loop nest.
889   rootAffineForOp.erase();
890 
891   if (tiledNest)
892     *tiledNest = std::move(tiledLoops);
893 
894   return success();
895 }
896 
897 /// Tiles the specified band of perfectly nested loops creating tile-space
898 /// loops and intra-tile loops, using SSA values as tiling parameters. A band
899 /// is a contiguous set of loops.
900 //  TODO: handle non hyper-rectangular spaces.
901 LogicalResult
902 mlir::tilePerfectlyNestedParametric(MutableArrayRef<AffineForOp> input,
903                                     ArrayRef<Value> tileSizes,
904                                     SmallVectorImpl<AffineForOp> *tiledNest) {
905   if (input.empty())
906     return success();
907 
908   if (failed(performPreTilingChecks(input, tileSizes)))
909     return failure();
910 
911   MutableArrayRef<AffineForOp> origLoops = input;
912   AffineForOp rootAffineForOp = origLoops[0];
913   unsigned width = input.size();
914   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
915 
916   // Construct a tiled loop nest without setting their bounds. Bounds are
917   // set later.
918   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
919 
920   SmallVector<Value, 8> origLoopIVs;
921   extractForInductionVars(input, &origLoopIVs);
922 
923   // Set loop bounds for the tiled loop nest.
924   constructParametricallyTiledIndexSetHyperRect(origLoops, tiledLoops,
925                                                 tileSizes);
926 
927   // Replace original IVs with intra-tile loop IVs.
928   for (unsigned i = 0; i < width; i++)
929     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
930 
931   // Erase the old loop nest.
932   rootAffineForOp.erase();
933 
934   if (tiledNest)
935     *tiledNest = std::move(tiledLoops);
936 
937   return success();
938 }
939 
940 /// Get perfectly nested sequence of loops starting at root of loop nest
941 /// (the first op being another AffineFor, and the second op - a terminator).
942 /// A loop is perfectly nested iff: the first op in the loop's body is another
943 /// AffineForOp, and the second op is a terminator).
944 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
945                                    AffineForOp root) {
946   for (unsigned i = 0; i < std::numeric_limits<unsigned>::max(); ++i) {
947     nestedLoops.push_back(root);
948     Block &body = root.getRegion().front();
949     if (body.begin() != std::prev(body.end(), 2))
950       return;
951 
952     root = dyn_cast<AffineForOp>(&body.front());
953     if (!root)
954       return;
955   }
956 }
957 
958 /// Identify valid and profitable bands of loops to tile. This is currently just
959 /// a temporary placeholder to test the mechanics of tiled code generation.
960 /// Returns all maximal outermost perfect loop nests to tile.
961 void mlir::getTileableBands(FuncOp f,
962                             std::vector<SmallVector<AffineForOp, 6>> *bands) {
963   // Get maximal perfect nest of 'affine.for' insts starting from root
964   // (inclusive).
965   for (AffineForOp forOp : f.getOps<AffineForOp>()) {
966     SmallVector<AffineForOp, 6> band;
967     getPerfectlyNestedLoops(band, forOp);
968     bands->push_back(band);
969   }
970 }
971 
972 /// Unrolls this loop completely.
973 LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
974   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
975   if (mayBeConstantTripCount.hasValue()) {
976     uint64_t tripCount = mayBeConstantTripCount.getValue();
977     if (tripCount == 0)
978       return success();
979     if (tripCount == 1)
980       return promoteIfSingleIteration(forOp);
981     return loopUnrollByFactor(forOp, tripCount);
982   }
983   return failure();
984 }
985 
986 /// Unrolls this loop by the specified factor or by the trip count (if constant)
987 /// whichever is lower.
988 LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
989                                          uint64_t unrollFactor) {
990   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
991   if (mayBeConstantTripCount.hasValue() &&
992       mayBeConstantTripCount.getValue() < unrollFactor)
993     return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
994   return loopUnrollByFactor(forOp, unrollFactor);
995 }
996 
997 /// Generates unrolled copies of AffineForOp 'loopBodyBlock', with associated
998 /// 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap 'forOpIV' for each
999 /// unrolled body. If specified, annotates the Ops in each unrolled iteration
1000 /// using annotateFn.
1001 static void generateUnrolledLoop(
1002     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
1003     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
1004     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1005     ValueRange iterArgs, ValueRange yieldedValues) {
1006   // Builder to insert unrolled bodies just before the terminator of the body of
1007   // 'forOp'.
1008   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
1009 
1010   if (!annotateFn)
1011     annotateFn = [](unsigned, Operation *, OpBuilder) {};
1012 
1013   // Keep a pointer to the last non-terminator operation in the original block
1014   // so that we know what to clone (since we are doing this in-place).
1015   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
1016 
1017   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
1018   SmallVector<Value, 4> lastYielded(yieldedValues);
1019 
1020   for (unsigned i = 1; i < unrollFactor; i++) {
1021     BlockAndValueMapping operandMap;
1022 
1023     // Prepare operand map.
1024     operandMap.map(iterArgs, lastYielded);
1025 
1026     // If the induction variable is used, create a remapping to the value for
1027     // this unrolled instance.
1028     if (!forOpIV.use_empty()) {
1029       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
1030       operandMap.map(forOpIV, ivUnroll);
1031     }
1032 
1033     // Clone the original body of 'forOp'.
1034     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
1035       Operation *clonedOp = builder.clone(*it, operandMap);
1036       annotateFn(i, clonedOp, builder);
1037     }
1038 
1039     // Update yielded values.
1040     for (unsigned i = 0, e = lastYielded.size(); i < e; i++)
1041       lastYielded[i] = operandMap.lookup(yieldedValues[i]);
1042   }
1043 
1044   // Make sure we annotate the Ops in the original body. We do this last so that
1045   // any annotations are not copied into the cloned Ops above.
1046   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
1047     annotateFn(0, &*it, builder);
1048 
1049   // Update operands of the yield statement.
1050   loopBodyBlock->getTerminator()->setOperands(lastYielded);
1051 }
1052 
1053 /// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip
1054 /// count is not a multiple of `unrollFactor`.
1055 static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
1056                                                   uint64_t unrollFactor) {
1057   // Insert the cleanup loop right after 'forOp'.
1058   OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
1059   auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
1060 
1061   // Update uses of `forOp` results. `cleanupForOp` should use `forOp` result
1062   // and produce results for the original users of `forOp` results.
1063   auto results = forOp.getResults();
1064   auto cleanupResults = cleanupForOp.getResults();
1065   auto cleanupIterOperands = cleanupForOp.getIterOperands();
1066 
1067   for (auto e : llvm::zip(results, cleanupResults, cleanupIterOperands)) {
1068     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
1069     cleanupForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
1070   }
1071 
1072   AffineMap cleanupMap;
1073   SmallVector<Value, 4> cleanupOperands;
1074   getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
1075   if (!cleanupMap)
1076     return failure();
1077 
1078   cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
1079   // Promote the loop body up if this has turned into a single iteration loop.
1080   (void)promoteIfSingleIteration(cleanupForOp);
1081 
1082   // Adjust upper bound of the original loop; this is the same as the lower
1083   // bound of the cleanup loop.
1084   forOp.setUpperBound(cleanupOperands, cleanupMap);
1085   return success();
1086 }
1087 
1088 /// Unrolls this loop by the specified factor. Returns success if the loop
1089 /// is successfully unrolled.
1090 LogicalResult mlir::loopUnrollByFactor(
1091     AffineForOp forOp, uint64_t unrollFactor,
1092     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
1093   assert(unrollFactor > 0 && "unroll factor should be positive");
1094 
1095   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1096   if (unrollFactor == 1) {
1097     if (mayBeConstantTripCount.hasValue() &&
1098         mayBeConstantTripCount.getValue() == 1 &&
1099         failed(promoteIfSingleIteration(forOp)))
1100       return failure();
1101     return success();
1102   }
1103 
1104   // Nothing in the loop body other than the terminator.
1105   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1106     return success();
1107 
1108   // If the trip count is lower than the unroll factor, no unrolled body.
1109   // TODO: option to specify cleanup loop unrolling.
1110   if (mayBeConstantTripCount.hasValue() &&
1111       mayBeConstantTripCount.getValue() < unrollFactor)
1112     return failure();
1113 
1114   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
1115   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
1116     // Loops where the lower bound is a max expression or the upper bound is
1117     // a min expression and the trip count doesn't divide the unroll factor
1118     // can't be unrolled since the lower bound of the cleanup loop in such cases
1119     // cannot be expressed as an affine function or a max over affine functions.
1120     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1121         forOp.getUpperBoundMap().getNumResults() != 1)
1122       return failure();
1123     if (failed(generateCleanupLoopForUnroll(forOp, unrollFactor)))
1124       assert(false && "cleanup loop lower bound map for single result lower "
1125                       "and upper bound maps can always be determined");
1126   }
1127 
1128   ValueRange iterArgs(forOp.getRegionIterArgs());
1129   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
1130 
1131   // Scale the step of loop being unrolled by unroll factor.
1132   int64_t step = forOp.getStep();
1133   forOp.setStep(step * unrollFactor);
1134   generateUnrolledLoop(
1135       forOp.getBody(), forOp.getInductionVar(), unrollFactor,
1136       [&](unsigned i, Value iv, OpBuilder b) {
1137         // iv' = iv + i * step
1138         auto d0 = b.getAffineDimExpr(0);
1139         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1140         return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
1141       },
1142       /*annotateFn=*/annotateFn,
1143       /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
1144 
1145   // Promote the loop body up if this has turned into a single iteration loop.
1146   (void)promoteIfSingleIteration(forOp);
1147   return success();
1148 }
1149 
1150 LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
1151                                             uint64_t unrollJamFactor) {
1152   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1153   if (mayBeConstantTripCount.hasValue() &&
1154       mayBeConstantTripCount.getValue() < unrollJamFactor)
1155     return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
1156   return loopUnrollJamByFactor(forOp, unrollJamFactor);
1157 }
1158 
1159 /// Check if all control operands of all loops are defined outside of `forOp`
1160 /// and return false if not.
1161 static bool areInnerBoundsInvariant(AffineForOp forOp) {
1162   auto walkResult = forOp.walk([&](AffineForOp aForOp) {
1163     for (auto controlOperand : aForOp.getControlOperands()) {
1164       if (!forOp.isDefinedOutsideOfLoop(controlOperand))
1165         return WalkResult::interrupt();
1166     }
1167     return WalkResult::advance();
1168   });
1169   return !walkResult.wasInterrupted();
1170 }
1171 
1172 // Gathers all maximal sub-blocks of operations that do not themselves
1173 // include a for op (a operation could have a descendant for op though
1174 // in its tree).  Ignore the block terminators.
1175 struct JamBlockGatherer {
1176   // Store iterators to the first and last op of each sub-block found.
1177   std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
1178 
1179   // This is a linear time walk.
1180   void walk(Operation *op) {
1181     for (auto &region : op->getRegions())
1182       for (auto &block : region)
1183         walk(block);
1184   }
1185 
1186   void walk(Block &block) {
1187     for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
1188       auto subBlockStart = it;
1189       while (it != e && !isa<AffineForOp>(&*it))
1190         ++it;
1191       if (it != subBlockStart)
1192         subBlocks.emplace_back(subBlockStart, std::prev(it));
1193       // Process all for ops that appear next.
1194       while (it != e && isa<AffineForOp>(&*it))
1195         walk(&*it++);
1196     }
1197   }
1198 };
1199 
1200 /// Unrolls and jams this loop by the specified factor.
1201 LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
1202                                           uint64_t unrollJamFactor) {
1203   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
1204 
1205   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1206   if (unrollJamFactor == 1) {
1207     if (mayBeConstantTripCount.hasValue() &&
1208         mayBeConstantTripCount.getValue() == 1 &&
1209         failed(promoteIfSingleIteration(forOp)))
1210       return failure();
1211     return success();
1212   }
1213 
1214   // Nothing in the loop body other than the terminator.
1215   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1216     return success();
1217 
1218   // If the trip count is lower than the unroll jam factor, no unroll jam.
1219   if (mayBeConstantTripCount.hasValue() &&
1220       mayBeConstantTripCount.getValue() < unrollJamFactor) {
1221     LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
1222     return failure();
1223   }
1224 
1225   // If any control operand of any inner loop of `forOp` is defined within
1226   // `forOp`, no unroll jam.
1227   if (!areInnerBoundsInvariant(forOp))
1228     return failure();
1229 
1230   // Gather all sub-blocks to jam upon the loop being unrolled.
1231   JamBlockGatherer jbg;
1232   jbg.walk(forOp);
1233   auto &subBlocks = jbg.subBlocks;
1234 
1235   // Collect loops with iter_args.
1236   SmallVector<AffineForOp, 4> loopsWithIterArgs;
1237   forOp.walk([&](AffineForOp aForOp) {
1238     if (aForOp.getNumIterOperands() > 0)
1239       loopsWithIterArgs.push_back(aForOp);
1240   });
1241 
1242   // Get supported reductions to be used for creating reduction ops at the end.
1243   SmallVector<LoopReduction> reductions;
1244   if (forOp.getNumIterOperands() > 0)
1245     getSupportedReductions(forOp, reductions);
1246 
1247   // Generate the cleanup loop if trip count isn't a multiple of
1248   // unrollJamFactor.
1249   if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
1250     // Loops where the lower bound is a max expression or the upper bound is
1251     // a min expression and the trip count doesn't divide the unroll factor
1252     // can't be unrolled since the lower bound of the cleanup loop in such cases
1253     // cannot be expressed as an affine function or a max over affine functions.
1254     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1255         forOp.getUpperBoundMap().getNumResults() != 1)
1256       return failure();
1257     if (failed(generateCleanupLoopForUnroll(forOp, unrollJamFactor)))
1258       assert(false && "cleanup loop lower bound map for single result lower "
1259                       "and upper bound maps can always be determined");
1260   }
1261 
1262   // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
1263   // iteration. There are (`unrollJamFactor` - 1) iterations.
1264   SmallVector<BlockAndValueMapping, 4> operandMaps(unrollJamFactor - 1);
1265 
1266   // For any loop with iter_args, replace it with a new loop that has
1267   // `unrollJamFactor` copies of its iterOperands, iter_args and yield
1268   // operands.
1269   SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
1270   OpBuilder builder(forOp.getContext());
1271   for (AffineForOp oldForOp : loopsWithIterArgs) {
1272     SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
1273     ValueRange oldIterOperands = oldForOp.getIterOperands();
1274     ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
1275     ValueRange oldYieldOperands =
1276         cast<AffineYieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
1277     // Get additional iterOperands, iterArgs, and yield operands. We will
1278     // fix iterOperands and yield operands after cloning of sub-blocks.
1279     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1280       dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
1281       dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
1282       dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
1283     }
1284     // Create a new loop with additional iterOperands, iter_args and yield
1285     // operands. This new loop will take the loop body of the original loop.
1286     AffineForOp newForOp = mlir::replaceForOpWithNewYields(
1287         builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
1288     newLoopsWithIterArgs.push_back(newForOp);
1289     // `forOp` has been replaced with a new loop.
1290     if (oldForOp == forOp)
1291       forOp = newForOp;
1292     assert(oldForOp.use_empty() && "old for op should not have any user");
1293     oldForOp.erase();
1294     // Update `operandMaps` for `newForOp` iterArgs and results.
1295     ValueRange newIterArgs = newForOp.getRegionIterArgs();
1296     unsigned oldNumIterArgs = oldIterArgs.size();
1297     ValueRange newResults = newForOp.getResults();
1298     unsigned oldNumResults = newResults.size() / unrollJamFactor;
1299     assert(oldNumIterArgs == oldNumResults &&
1300            "oldNumIterArgs must be the same as oldNumResults");
1301     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1302       for (unsigned j = 0; j < oldNumIterArgs; ++j) {
1303         // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
1304         // results. Update `operandMaps[i - 1]` to map old iterArgs and results
1305         // to those in the `i`th new set.
1306         operandMaps[i - 1].map(newIterArgs[j],
1307                                newIterArgs[i * oldNumIterArgs + j]);
1308         operandMaps[i - 1].map(newResults[j],
1309                                newResults[i * oldNumResults + j]);
1310       }
1311     }
1312   }
1313 
1314   // Scale the step of loop being unroll-jammed by the unroll-jam factor.
1315   int64_t step = forOp.getStep();
1316   forOp.setStep(step * unrollJamFactor);
1317 
1318   auto forOpIV = forOp.getInductionVar();
1319   // Unroll and jam (appends unrollJamFactor - 1 additional copies).
1320   for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1321     for (auto &subBlock : subBlocks) {
1322       // Builder to insert unroll-jammed bodies. Insert right at the end of
1323       // sub-block.
1324       OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
1325 
1326       // If the induction variable is used, create a remapping to the value for
1327       // this unrolled instance.
1328       if (!forOpIV.use_empty()) {
1329         // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
1330         auto d0 = builder.getAffineDimExpr(0);
1331         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1332         auto ivUnroll =
1333             builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
1334         operandMaps[i - 1].map(forOpIV, ivUnroll);
1335       }
1336       // Clone the sub-block being unroll-jammed.
1337       for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
1338         builder.clone(*it, operandMaps[i - 1]);
1339     }
1340     // Fix iterOperands and yield op operands of newly created loops.
1341     for (auto newForOp : newLoopsWithIterArgs) {
1342       unsigned oldNumIterOperands =
1343           newForOp.getNumIterOperands() / unrollJamFactor;
1344       unsigned numControlOperands = newForOp.getNumControlOperands();
1345       auto yieldOp = cast<AffineYieldOp>(newForOp.getBody()->getTerminator());
1346       unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
1347       assert(oldNumIterOperands == oldNumYieldOperands &&
1348              "oldNumIterOperands must be the same as oldNumYieldOperands");
1349       for (unsigned j = 0; j < oldNumIterOperands; ++j) {
1350         // The `i`th duplication of an old iterOperand or yield op operand
1351         // needs to be replaced with a mapped value from `operandMaps[i - 1]`
1352         // if such mapped value exists.
1353         newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
1354                             operandMaps[i - 1].lookupOrDefault(
1355                                 newForOp.getOperand(numControlOperands + j)));
1356         yieldOp.setOperand(
1357             i * oldNumYieldOperands + j,
1358             operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
1359       }
1360     }
1361   }
1362   if (forOp.getNumResults() > 0) {
1363     // Create reduction ops to combine every `unrollJamFactor` related results
1364     // into one value. For example, for %0:2 = affine.for ... and addf, we add
1365     // %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
1366     // %1.
1367     builder.setInsertionPointAfter(forOp);
1368     auto loc = forOp.getLoc();
1369     unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
1370     for (LoopReduction &reduction : reductions) {
1371       unsigned pos = reduction.iterArgPosition;
1372       Value lhs = forOp.getResult(pos);
1373       Value rhs;
1374       SmallPtrSet<Operation *, 4> newOps;
1375       for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1376         rhs = forOp.getResult(i * oldNumResults + pos);
1377         // Create ops based on reduction type.
1378         lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
1379         if (!lhs)
1380           return failure();
1381         Operation *op = lhs.getDefiningOp();
1382         assert(op && "Reduction op should have been created");
1383         newOps.insert(op);
1384       }
1385       // Replace all uses except those in newly created reduction ops.
1386       forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
1387     }
1388   }
1389 
1390   // Promote the loop body up if this has turned into a single iteration loop.
1391   (void)promoteIfSingleIteration(forOp);
1392   return success();
1393 }
1394 
/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
/// nested within 'forOpA' as the only non-terminator operation in its block.
/// The interchange only splices operation lists: no operation is created,
/// cloned or erased, and both loop ops keep their own bounds and steps; only
/// their nesting order is swapped.
/// NOTE(review): legality is not checked here; in particular the caller must
/// ensure forOpB's bound operands do not use forOpA's induction variable,
/// since forOpB ends up outside forOpA.
void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
  // Only verifies that forOpB is the first op of forOpA's body; that it is the
  // *only* non-terminator op there is assumed, not checked.
  assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
  auto &forOpABody = forOpA.getBody()->getOperations();
  auto &forOpBBody = forOpB.getBody()->getOperations();

  // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
  // before forOpA (in forOpA's parent's block); this should leave forOpA's
  // body containing only the terminator.
  forOpA->getBlock()->getOperations().splice(Block::iterator(forOpA),
                                             forOpABody, forOpABody.begin(),
                                             std::prev(forOpABody.end()));
  // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
  // body (this leaves forOpB's body containing only the terminator).
  forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
                    std::prev(forOpBBody.end()));
  // 3) Splice forOpA into the beginning of forOpB's body. After step 1 forOpA
  // sits right after forOpB in the original parent block and is moved from
  // there.
  forOpBBody.splice(forOpBBody.begin(), forOpA->getBlock()->getOperations(),
                    Block::iterator(forOpA));
}
1416 
1417 // Checks each dependence component against the permutation to see if the
1418 // desired loop interchange would violate dependences by making the
1419 // dependence component lexicographically negative.
1420 static bool checkLoopInterchangeDependences(
1421     const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
1422     ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
1423   // Invert permutation map.
1424   unsigned maxLoopDepth = loops.size();
1425   SmallVector<unsigned, 4> loopPermMapInv;
1426   loopPermMapInv.resize(maxLoopDepth);
1427   for (unsigned i = 0; i < maxLoopDepth; ++i)
1428     loopPermMapInv[loopPermMap[i]] = i;
1429 
1430   // Check each dependence component against the permutation to see if the
1431   // desired loop interchange permutation would make the dependence vectors
1432   // lexicographically negative.
1433   // Example 1: [-1, 1][0, 0]
1434   // Example 2: [0, 0][-1, 1]
1435   for (const auto &depComps : depCompsVec) {
1436     assert(depComps.size() >= maxLoopDepth);
1437     // Check if the first non-zero dependence component is positive.
1438     // This iterates through loops in the desired order.
1439     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1440       unsigned permIndex = loopPermMapInv[j];
1441       assert(depComps[permIndex].lb.hasValue());
1442       int64_t depCompLb = depComps[permIndex].lb.getValue();
1443       if (depCompLb > 0)
1444         break;
1445       if (depCompLb < 0)
1446         return false;
1447     }
1448   }
1449   return true;
1450 }
1451 
1452 /// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
1453 /// nested sequence of loops in 'loops' would violate dependences.
1454 bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
1455                                              ArrayRef<unsigned> loopPermMap) {
1456   // Gather dependence components for dependences between all ops in loop nest
1457   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1458   assert(loopPermMap.size() == loops.size());
1459   unsigned maxLoopDepth = loops.size();
1460   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1461   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1462   return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
1463 }
1464 
1465 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
1466 /// in it from outermost to innermost.
1467 bool LLVM_ATTRIBUTE_UNUSED
1468 mlir::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
1469   assert(!loops.empty() && "no loops provided");
1470 
1471   // We already know that the block can't be empty.
1472   auto hasTwoElements = [](Block *block) {
1473     auto secondOpIt = std::next(block->begin());
1474     return secondOpIt != block->end() && &*secondOpIt == &block->back();
1475   };
1476 
1477   auto enclosingLoop = loops.front();
1478   for (auto loop : loops.drop_front()) {
1479     auto parentForOp = dyn_cast<AffineForOp>(loop->getParentOp());
1480     // parentForOp's body should be just this loop and the terminator.
1481     if (parentForOp != enclosingLoop || !hasTwoElements(parentForOp.getBody()))
1482       return false;
1483     enclosingLoop = loop;
1484   }
1485   return true;
1486 }
1487 
1488 // input[i] should move from position i -> permMap[i]. Returns the position in
1489 // `input` that becomes the new outermost loop.
1490 unsigned mlir::permuteLoops(MutableArrayRef<AffineForOp> input,
1491                             ArrayRef<unsigned> permMap) {
1492   assert(input.size() == permMap.size() && "invalid permutation map size");
1493   // Check whether the permutation spec is valid. This is a small vector - we'll
1494   // just sort and check if it's iota.
1495   SmallVector<unsigned, 4> checkPermMap(permMap.begin(), permMap.end());
1496   llvm::sort(checkPermMap);
1497   if (llvm::any_of(llvm::enumerate(checkPermMap),
1498                    [](const auto &en) { return en.value() != en.index(); }))
1499     assert(false && "invalid permutation map");
1500 
1501   // Nothing to do.
1502   if (input.size() < 2)
1503     return 0;
1504 
1505   assert(isPerfectlyNested(input) && "input not perfectly nested");
1506 
1507   // Compute the inverse mapping, invPermMap: since input[i] goes to position
1508   // permMap[i], position i of the permuted nest is at input[invPermMap[i]].
1509   SmallVector<std::pair<unsigned, unsigned>, 4> invPermMap;
1510   for (unsigned i = 0, e = input.size(); i < e; ++i)
1511     invPermMap.push_back({permMap[i], i});
1512   llvm::sort(invPermMap);
1513 
1514   // Move the innermost loop body to the loop that would be the innermost in the
1515   // permuted nest (only if the innermost loop is going to change).
1516   if (permMap.back() != input.size() - 1) {
1517     auto *destBody = input[invPermMap.back().second].getBody();
1518     auto *srcBody = input.back().getBody();
1519     destBody->getOperations().splice(destBody->begin(),
1520                                      srcBody->getOperations(), srcBody->begin(),
1521                                      std::prev(srcBody->end()));
1522   }
1523 
1524   // We'll move each loop in `input` in the reverse order so that its body is
1525   // empty when we are moving it; this incurs zero copies and no erasing.
1526   for (int i = input.size() - 1; i >= 0; --i) {
1527     // If this has to become the outermost loop after permutation, add it to the
1528     // parent block of the original root.
1529     if (permMap[i] == 0) {
1530       // If the root remains the same, nothing to do.
1531       if (i == 0)
1532         continue;
1533       // Make input[i] the new outermost loop moving it into parentBlock.
1534       auto *parentBlock = input[0]->getBlock();
1535       parentBlock->getOperations().splice(Block::iterator(input[0]),
1536                                           input[i]->getBlock()->getOperations(),
1537                                           Block::iterator(input[i]));
1538       continue;
1539     }
1540 
1541     // If the parent in the permuted order is the same as in the original,
1542     // nothing to do.
1543     unsigned parentPosInInput = invPermMap[permMap[i] - 1].second;
1544     if (i > 0 && static_cast<unsigned>(i - 1) == parentPosInInput)
1545       continue;
1546 
1547     // Move input[i] to its surrounding loop in the transformed nest.
1548     auto *destBody = input[parentPosInInput].getBody();
1549     destBody->getOperations().splice(destBody->begin(),
1550                                      input[i]->getBlock()->getOperations(),
1551                                      Block::iterator(input[i]));
1552   }
1553 
1554   return invPermMap[0].second;
1555 }
1556 
1557 // Sinks all sequential loops to the innermost levels (while preserving
1558 // relative order among them) and moves all parallel loops to the
1559 // outermost (while again preserving relative order among them).
1560 AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
1561   SmallVector<AffineForOp, 4> loops;
1562   getPerfectlyNestedLoops(loops, forOp);
1563   if (loops.size() < 2)
1564     return forOp;
1565 
1566   // Gather dependence components for dependences between all ops in loop nest
1567   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1568   unsigned maxLoopDepth = loops.size();
1569   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1570   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1571 
1572   // Mark loops as either parallel or sequential.
1573   SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
1574   for (auto &depComps : depCompsVec) {
1575     assert(depComps.size() >= maxLoopDepth);
1576     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1577       DependenceComponent &depComp = depComps[j];
1578       assert(depComp.lb.hasValue() && depComp.ub.hasValue());
1579       if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
1580         isParallelLoop[j] = false;
1581     }
1582   }
1583 
1584   // Count the number of parallel loops.
1585   unsigned numParallelLoops = 0;
1586   for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
1587     if (isParallelLoop[i])
1588       ++numParallelLoops;
1589 
1590   // Compute permutation of loops that sinks sequential loops (and thus raises
1591   // parallel loops) while preserving relative order.
1592   SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
1593   unsigned nextSequentialLoop = numParallelLoops;
1594   unsigned nextParallelLoop = 0;
1595   for (unsigned i = 0; i < maxLoopDepth; ++i) {
1596     if (isParallelLoop[i]) {
1597       loopPermMap[i] = nextParallelLoop++;
1598     } else {
1599       loopPermMap[i] = nextSequentialLoop++;
1600     }
1601   }
1602 
1603   // Check if permutation 'loopPermMap' would violate dependences.
1604   if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
1605     return forOp;
1606   // Perform loop interchange according to permutation 'loopPermMap'.
1607   unsigned loopNestRootIndex = permuteLoops(loops, loopPermMap);
1608   return loops[loopNestRootIndex];
1609 }
1610 
1611 // Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
1612 // lower (resp. upper) loop bound. When called for both the lower and upper
1613 // bounds, the resulting IR resembles:
1614 //
1615 // ```mlir
1616 //    affine.for %i = max (`iv, ...) to min (`iv` + `offset`) {
1617 //      ...
1618 //    }
1619 // ```
1620 static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
1621                                 SmallVector<Value, 4> *operands,
1622                                 int64_t offset = 0) {
1623   auto bounds = llvm::to_vector<4>(map->getResults());
1624   bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
1625   operands->insert(operands->begin() + map->getNumDims(), iv);
1626   *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds,
1627                         b.getContext());
1628   canonicalizeMapAndOperands(map, operands);
1629 }
1630 
1631 // Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
1632 // Stripmine-sink is a primitive building block for generalized tiling of
1633 // imperfectly nested loops.
1634 // This transformation is purely mechanical and does not check legality,
1635 // profitability or even structural correctness. It is the user's
1636 // responsibility to specify `targets` that are dominated by `forOp`.
1637 // Returns the new AffineForOps, one per `targets`, nested immediately under
1638 // each of the `targets`.
1639 static SmallVector<AffineForOp, 8>
1640 stripmineSink(AffineForOp forOp, uint64_t factor,
1641               ArrayRef<AffineForOp> targets) {
1642   auto originalStep = forOp.getStep();
1643   auto scaledStep = originalStep * factor;
1644   forOp.setStep(scaledStep);
1645 
1646   OpBuilder b(forOp->getBlock(), std::next(Block::iterator(forOp)));
1647 
1648   // Lower-bound map creation.
1649   auto lbMap = forOp.getLowerBoundMap();
1650   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1651   augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
1652 
1653   // Upper-bound map creation.
1654   auto ubMap = forOp.getUpperBoundMap();
1655   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1656   augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
1657                       /*offset=*/scaledStep);
1658 
1659   auto iv = forOp.getInductionVar();
1660   SmallVector<AffineForOp, 8> innerLoops;
1661   for (auto t : targets) {
1662     // Insert newForOp before the terminator of `t`.
1663     auto b = OpBuilder::atBlockTerminator(t.getBody());
1664     auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
1665                                           ubOperands, ubMap, originalStep);
1666     auto begin = t.getBody()->begin();
1667     // Skip terminator and `newForOp` which is just before the terminator.
1668     auto nOps = t.getBody()->getOperations().size() - 2;
1669     newForOp.getBody()->getOperations().splice(
1670         newForOp.getBody()->getOperations().begin(),
1671         t.getBody()->getOperations(), begin, std::next(begin, nOps));
1672     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1673                                newForOp.region());
1674     innerLoops.push_back(newForOp);
1675   }
1676 
1677   return innerLoops;
1678 }
1679 
1680 // Stripmines a `forOp` by `factor` and sinks it under a single `target`.
1681 // Returns the new AffineForOps, nested immediately under `target`.
1682 template <typename SizeType>
1683 static AffineForOp stripmineSink(AffineForOp forOp, SizeType factor,
1684                                  AffineForOp target) {
1685   // TODO: Use cheap structural assertions that targets are nested under
1686   // forOp and that targets are not nested under each other when DominanceInfo
1687   // exposes the capability. It seems overkill to construct a whole function
1688   // dominance tree at this point.
1689   auto res = stripmineSink(forOp, factor, ArrayRef<AffineForOp>(target));
1690   assert(res.size() == 1 && "Expected 1 inner forOp");
1691   return res[0];
1692 }
1693 
1694 SmallVector<SmallVector<AffineForOp, 8>, 8>
1695 mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
1696            ArrayRef<AffineForOp> targets) {
1697   SmallVector<SmallVector<AffineForOp, 8>, 8> res;
1698   SmallVector<AffineForOp, 8> currentTargets(targets.begin(), targets.end());
1699   for (auto it : llvm::zip(forOps, sizes)) {
1700     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1701     res.push_back(step);
1702     currentTargets = step;
1703   }
1704   return res;
1705 }
1706 
1707 SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
1708                                        ArrayRef<uint64_t> sizes,
1709                                        AffineForOp target) {
1710   SmallVector<AffineForOp, 8> res;
1711   for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
1712     assert(loops.size() == 1);
1713     res.push_back(loops[0]);
1714   }
1715   return res;
1716 }
1717 
/// Coalesces the band `loops` (outermost first) into a single loop.
///
/// All loops must be normalized (constant lower bound 0 and step 1);
/// otherwise failure() is returned before any IR is modified. Bands with
/// fewer than two loops are left alone and reported as success.
///
/// The outermost loop op is reused as the coalesced loop. Its upper bound is
/// replaced by the product of all the loops' upper bounds. Each original
/// induction variable is recomputed from the linearized one with floordiv/mod.
/// The innermost body is then moved into the outermost loop, and the remaining
/// inner loops are erased.
///
/// NOTE(review): the band is assumed to be perfectly nested, and inner loop
/// upper bounds are assumed not to use IVs of enclosing loops in the band,
/// because their computations are emitted before the outermost loop. Neither
/// assumption is checked here.
LogicalResult mlir::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
  if (loops.size() < 2)
    return success();

  AffineForOp innermost = loops.back();
  AffineForOp outermost = loops.front();
  AffineBound ub = outermost.getUpperBound();
  AffineMap origUbMap = ub.getMap();
  Location loc = outermost.getLoc();
  // All new ops are created right before the outermost loop.
  OpBuilder builder(outermost);
  for (AffineForOp loop : loops) {
    // We only work on normalized loops.
    if (loop.getStep() != 1 || !loop.hasConstantLowerBound() ||
        loop.getConstantLowerBound() != 0)
      return failure();
  }
  // upperBoundSymbols[i] holds the single-valued upper bound of loops[i].
  SmallVector<Value, 4> upperBoundSymbols;
  SmallVector<Value, 4> ubOperands(ub.getOperands().begin(),
                                   ub.getOperands().end());

  // 1. Store the upper bound of the outermost loop in a variable.
  // A multi-result upper bound map means "min of the results", hence
  // affine.min.
  Value prev;
  if (!llvm::hasSingleElement(origUbMap.getResults()))
    prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
  else
    prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
  upperBoundSymbols.push_back(prev);

  // 2. Emit code computing the upper bound of the coalesced loop as product of
  // the number of iterations of all loops.
  for (AffineForOp loop : loops.drop_front()) {
    ub = loop.getUpperBound();
    origUbMap = ub.getMap();
    ubOperands = ub.getOperands();
    Value upperBound;
    // If upper bound map has more than one result, take their minimum.
    if (!llvm::hasSingleElement(origUbMap.getResults()))
      upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
    else
      upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
    upperBoundSymbols.push_back(upperBound);
    SmallVector<Value, 4> operands;
    operands.push_back(prev);
    operands.push_back(upperBound);
    // Maintain running product of loop upper bounds.
    prev = builder.create<AffineApplyOp>(
        loc,
        AffineMap::get(/*numDims=*/1,
                       /*numSymbols=*/1,
                       builder.getAffineDimExpr(0) *
                           builder.getAffineSymbolExpr(0)),
        operands);
  }
  // Set upper bound of the coalesced loop.
  AffineMap newUbMap = AffineMap::get(
      /*numDims=*/0,
      /*numSymbols=*/1, builder.getAffineSymbolExpr(0), builder.getContext());
  outermost.setUpperBound(prev, newUbMap);

  builder.setInsertionPointToStart(outermost.getBody());

  // 3. Remap induction variables. For each original loop, the value of the
  // induction variable can be obtained by dividing the induction variable of
  // the linearized loop by the total number of iterations of the loops nested
  // in it modulo the number of iterations in this loop (remove the values
  // related to the outer loops):
  //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
  // Compute these iteratively from the innermost loop by creating a "running
  // quotient" of division by the range.
  Value previous = outermost.getInductionVar();
  for (unsigned idx = loops.size(); idx > 0; --idx) {
    // Divide out the range of the loop just inside loops[idx - 1].
    if (idx != loops.size()) {
      SmallVector<Value, 4> operands;
      operands.push_back(previous);
      operands.push_back(upperBoundSymbols[idx]);
      previous = builder.create<AffineApplyOp>(
          loc,
          AffineMap::get(
              /*numDims=*/1, /*numSymbols=*/1,
              builder.getAffineDimExpr(0).floorDiv(
                  builder.getAffineSymbolExpr(0))),
          operands);
    }
    // Modified value of the induction variables of the nested loops after
    // coalescing.
    Value inductionVariable;
    if (idx == 1) {
      // The outermost IV is the full remaining quotient; no mod is needed.
      inductionVariable = previous;
    } else {
      SmallVector<Value, 4> applyOperands;
      applyOperands.push_back(previous);
      applyOperands.push_back(upperBoundSymbols[idx - 1]);
      inductionVariable = builder.create<AffineApplyOp>(
          loc,
          AffineMap::get(
              /*numDims=*/1, /*numSymbols=*/1,
              builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)),
          applyOperands);
    }
    // Uses are only rewritten inside the innermost loop's region.
    replaceAllUsesInRegionWith(loops[idx - 1].getInductionVar(),
                               inductionVariable, loops.back().region());
  }

  // 4. Move the operations from the innermost just above the second-outermost
  // loop, delete the extra terminator and the second-outermost loop.
  AffineForOp secondOutermostLoop = loops[1];
  innermost.getBody()->back().erase();
  outermost.getBody()->getOperations().splice(
      Block::iterator(secondOutermostLoop.getOperation()),
      innermost.getBody()->getOperations());
  // Erasing the second-outermost loop also erases every loop nested in it.
  secondOutermostLoop.erase();
  return success();
}
1831 
1832 void mlir::mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
1833                                  ArrayRef<Value> numProcessors) {
1834   assert(processorId.size() == numProcessors.size());
1835   if (processorId.empty())
1836     return;
1837 
1838   OpBuilder b(forOp);
1839   Location loc(forOp.getLoc());
1840   AffineExpr lhs, rhs;
1841   bindSymbols(forOp.getContext(), lhs, rhs);
1842   auto mulMap = AffineMap::get(0, 2, lhs * rhs);
1843   auto addMap = AffineMap::get(0, 2, lhs + rhs);
1844 
1845   Value linearIndex = processorId.front();
1846   for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
1847     auto mulApplyOp = b.create<AffineApplyOp>(
1848         loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
1849     linearIndex = b.create<AffineApplyOp>(
1850         loc, addMap, ValueRange{mulApplyOp, processorId[i]});
1851   }
1852 
1853   auto mulApplyOp = b.create<AffineApplyOp>(
1854       loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
1855   Value lb = b.create<AffineApplyOp>(
1856       loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
1857   forOp.setLowerBound(lb);
1858 
1859   Value step = forOp.getStep();
1860   for (auto numProcs : numProcessors)
1861     step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
1862   forOp.setStep(step);
1863 }
1864 
1865 /// Given a memref region, determine the lowest depth at which transfers can be
1866 /// placed for it, and return the corresponding block, start and end positions
1867 /// in the block for placing incoming (read) and outgoing (write) copies
1868 /// respectively. The lowest depth depends on whether the region being accessed
1869 /// is hoistable with respect to one or more immediately surrounding loops.
1870 static void
1871 findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
1872                              Block::iterator &begin, Block::iterator &end,
1873                              Block **copyPlacementBlock,
1874                              Block::iterator *copyInPlacementStart,
1875                              Block::iterator *copyOutPlacementStart) {
1876   const auto *cst = region.getConstraints();
1877   SmallVector<Value, 4> symbols;
1878   cst->getValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
1879 
1880   SmallVector<AffineForOp, 4> enclosingFors;
1881   getLoopIVs(*block.begin(), &enclosingFors);
1882   // Walk up loop parents till we find an IV on which this region is
1883   // symbolic/variant.
1884   auto it = enclosingFors.rbegin();
1885   for (auto e = enclosingFors.rend(); it != e; ++it) {
1886     // TODO: also need to be checking this for regions symbols that
1887     // aren't loop IVs, whether we are within their resp. defs' dominance scope.
1888     if (llvm::is_contained(symbols, it->getInductionVar()))
1889       break;
1890   }
1891 
1892   if (it != enclosingFors.rbegin()) {
1893     auto lastInvariantIV = *std::prev(it);
1894     *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
1895     *copyOutPlacementStart = std::next(*copyInPlacementStart);
1896     *copyPlacementBlock = lastInvariantIV->getBlock();
1897   } else {
1898     *copyInPlacementStart = begin;
1899     *copyOutPlacementStart = end;
1900     *copyPlacementBlock = &block;
1901   }
1902 }
1903 
namespace {
/// One level of a strided copy/transfer. As computed by getMultiLevelStrides,
/// `stride` is a product of memref dimension sizes (the element distance
/// between consecutive strided chunks). `numEltPerStride` is the number of
/// elements transferred per chunk.
// Kept in an anonymous namespace, like the file's other local helper types, so
// this file-local name cannot clash with a `StrideInfo` from another
// translation unit.
struct StrideInfo {
  int64_t stride;
  int64_t numEltPerStride;
};
} // namespace
1909 
1910 /// Returns striding information for a copy/transfer of this region with
1911 /// potentially multiple striding levels from outermost to innermost. For an
1912 /// n-dimensional region, there can be at most n-1 levels of striding
1913 /// successively nested.
1914 //  TODO: make this work with non-identity layout maps.
1915 static void getMultiLevelStrides(const MemRefRegion &region,
1916                                  ArrayRef<int64_t> bufferShape,
1917                                  SmallVectorImpl<StrideInfo> *strideInfos) {
1918   if (bufferShape.size() <= 1)
1919     return;
1920 
1921   int64_t numEltPerStride = 1;
1922   int64_t stride = 1;
1923   for (int d = bufferShape.size() - 1; d >= 1; d--) {
1924     int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
1925     stride *= dimSize;
1926     numEltPerStride *= bufferShape[d];
1927     // A stride is needed only if the region has a shorter extent than the
1928     // memref along the dimension *and* has an extent greater than one along the
1929     // next major dimension.
1930     if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
1931       strideInfos->push_back({stride, numEltPerStride});
1932     }
1933   }
1934 }
1935 
/// Generates a point-wise copy from/to `memref` to/from `fastMemRef` and
/// returns the outermost AffineForOp of the copy loop nest. `lbMaps` and
/// `ubMaps` along with `lbOperands` and `ubOperands` hold the lower and upper
/// bound information for the copy loop nest. `fastBufOffsets` contain the
/// expressions to be subtracted out from the respective copy loop iterators in
/// order to index the fast buffer. If `isCopyOut` is true, generates a
/// copy-out; otherwise a copy-in. Builder `b` should be set to the point the
/// copy nest is inserted.
///
/// One loop is created per memref dimension, each nested in the previous one.
/// All lower-bound maps share `lbOperands` and all upper-bound maps share
/// `ubOperands` (asserted below). `fastBufOffsets[d]` is also evaluated over
/// `lbOperands`.
///
/// The copy-in nest is generated as follows as an example for a 2-d region:
/// for x = ...
///   for y = ...
///     fast_buf[x - offset_x][y - offset_y] = memref[x][y]
///
/// NOTE(review): for a rank-0 memref no loop is created. The load/store pair is
/// then emitted at `b`'s original insertion point and a null AffineForOp is
/// returned.
static AffineForOp
generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
                      ArrayRef<AffineMap> lbMaps, ArrayRef<Value> lbOperands,
                      ArrayRef<AffineMap> ubMaps, ArrayRef<Value> ubOperands,
                      ArrayRef<AffineExpr> fastBufOffsets, bool isCopyOut,
                      OpBuilder b) {
  assert(llvm::all_of(lbMaps, [&](AffineMap lbMap) {
    return lbMap.getNumInputs() == lbOperands.size();
  }));
  assert(llvm::all_of(ubMaps, [&](AffineMap ubMap) {
    return ubMap.getNumInputs() == ubOperands.size();
  }));

  unsigned rank = memref.getType().cast<MemRefType>().getRank();
  assert(lbMaps.size() == rank && "wrong number of lb maps");
  assert(ubMaps.size() == rank && "wrong number of ub maps");

  SmallVector<Value, 4> memIndices;
  SmallVector<AffineExpr, 4> fastBufExprs;
  SmallVector<Value, 4> fastBufMapOperands;
  AffineForOp copyNestRoot;
  SmallVector<AffineApplyOp, 4> mayBeDeadApplys;
  for (unsigned d = 0; d < rank; ++d) {
    auto forOp = createCanonicalizedAffineForOp(b, loc, lbOperands, lbMaps[d],
                                                ubOperands, ubMaps[d]);
    if (d == 0)
      copyNestRoot = forOp;

    // Move the (by-value) builder into the new loop so the next loop, and
    // finally the load/store, are nested inside it.
    b = OpBuilder::atBlockTerminator(forOp.getBody());

    auto fastBufOffsetMap =
        AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
    auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);

    // Construct the subscript for the fast memref being copied into/from:
    // x - offset_x.
    // Operands come in (offset_d, iv_d) pairs: dim 2*d is the offset and dim
    // 2*d + 1 is the loop IV.
    fastBufExprs.push_back(b.getAffineDimExpr(2 * d + 1) -
                           b.getAffineDimExpr(2 * d));
    fastBufMapOperands.push_back(offset);
    fastBufMapOperands.push_back(forOp.getInductionVar());
    mayBeDeadApplys.push_back(offset);

    // Subscript for the slow memref being copied.
    memIndices.push_back(forOp.getInductionVar());
  }

  // Fold the offset applys into the fast buffer map and simplify it; this can
  // leave some of the offset apply ops without uses.
  auto fastBufMap =
      AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext());
  fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands);
  fastBufMap = simplifyAffineMap(fastBufMap);
  canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands);

  // Drop any dead affine.applys.
  for (auto applyOp : mayBeDeadApplys)
    if (applyOp.use_empty())
      applyOp.erase();

  if (!isCopyOut) {
    // Copy in.
    auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
    b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
                            fastBufMapOperands);
    return copyNestRoot;
  }

  // Copy out.
  auto load =
      b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
  b.create<AffineStoreOp>(loc, load, memref, memIndices);
  return copyNestRoot;
}
2021 
2022 static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
2023 emitRemarkForBlock(Block &block) {
2024   return block.getParentOp()->emitRemark();
2025 }
2026 
2027 /// Creates a buffer in the faster memory space for the specified memref region;
2028 /// generates a copy from the lower memory space to this one, and replaces all
2029 /// loads/stores in the block range [`begin', `end') of `block' to load/store
2030 /// from that buffer. Returns failure if copies could not be generated due to
2031 /// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
2032 /// in copyPlacementBlock specify the insertion points where the incoming copies
2033 /// and outgoing copies, respectively, should be inserted (the insertion happens
2034 /// right before the insertion point). Since `begin` can itself be invalidated
2035 /// due to the memref rewriting done from this method, the output argument
2036 /// `nBegin` is set to its replacement (set to `begin` if no invalidation
/// happens). Since outgoing copies could have been inserted at `end`, the
2038 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
2039 /// size of the fast buffer allocated.
2040 static LogicalResult generateCopy(
2041     const MemRefRegion &region, Block *block, Block::iterator begin,
2042     Block::iterator end, Block *copyPlacementBlock,
2043     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
2044     AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
2045     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
2046     Block::iterator *nBegin, Block::iterator *nEnd) {
2047   *nBegin = begin;
2048   *nEnd = end;
2049 
2050   FuncOp f = begin->getParentOfType<FuncOp>();
2051   OpBuilder topBuilder(f.getBody());
2052   Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
2053 
2054   if (begin == end)
2055     return success();
2056 
2057   // Is the copy out point at the end of the block where we are doing
2058   // explicit copying.
2059   bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
2060 
2061   // Copies for read regions are going to be inserted at 'begin'.
2062   OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
2063   // Copies for write regions are going to be inserted at 'end'.
2064   OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
2065   OpBuilder &b = region.isWrite() ? epilogue : prologue;
2066 
2067   // Builder to create constants at the top level.
2068   auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
2069   OpBuilder top(func.getBody());
2070 
2071   auto loc = region.loc;
2072   auto memref = region.memref;
2073   auto memRefType = memref.getType().cast<MemRefType>();
2074 
2075   if (!memRefType.getLayout().isIdentity()) {
2076     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
2077     return failure();
2078   }
2079 
2080   // Indices to use for the copying.
2081   // Indices for the original memref being copied from/to.
2082   SmallVector<Value, 4> memIndices;
2083   // Indices for the faster buffer being copied into/from.
2084   SmallVector<Value, 4> bufIndices;
2085 
2086   unsigned rank = memRefType.getRank();
2087   SmallVector<int64_t, 4> fastBufferShape;
2088 
2089   // Compute the extents of the buffer.
2090   std::vector<SmallVector<int64_t, 4>> lbs;
2091   SmallVector<int64_t, 8> lbDivisors;
2092   lbs.reserve(rank);
2093   Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
2094       &fastBufferShape, &lbs, &lbDivisors);
2095   if (!numElements.hasValue()) {
2096     LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
2097     return failure();
2098   }
2099 
2100   if (numElements.getValue() == 0) {
2101     LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
2102     *sizeInBytes = 0;
2103     return success();
2104   }
2105 
2106   SmallVector<AffineMap, 4> lbMaps(rank), ubMaps(rank);
2107   for (unsigned i = 0; i < rank; ++i)
2108     region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
2109 
2110   const FlatAffineValueConstraints *cst = region.getConstraints();
2111   // 'regionSymbols' hold values that this memory region is symbolic/parametric
2112   // on; these typically include loop IVs surrounding the level at which the
2113   // copy generation is being done or other valid symbols in MLIR.
2114   SmallVector<Value, 8> regionSymbols;
2115   cst->getValues(rank, cst->getNumIds(), &regionSymbols);
2116 
2117   // Construct the index expressions for the fast memory buffer. The index
2118   // expression for a particular dimension of the fast buffer is obtained by
2119   // subtracting out the lower bound on the original memref's data region
2120   // along the corresponding dimension.
2121 
2122   // Index start offsets for faster memory buffer relative to the original.
2123   SmallVector<AffineExpr, 4> fastBufOffsets;
2124   fastBufOffsets.reserve(rank);
2125   for (unsigned d = 0; d < rank; d++) {
2126     assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
2127 
2128     AffineExpr offset = top.getAffineConstantExpr(0);
2129     for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++)
2130       offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
2131     assert(lbDivisors[d] > 0);
2132     offset =
2133         (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
2134 
2135     // Set copy start location for this dimension in the lower memory space
2136     // memref.
2137     if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
2138       auto indexVal = caf.getValue();
2139       if (indexVal == 0) {
2140         memIndices.push_back(zeroIndex);
2141       } else {
2142         memIndices.push_back(
2143             top.create<arith::ConstantIndexOp>(loc, indexVal).getResult());
2144       }
2145     } else {
2146       // The coordinate for the start location is just the lower bound along the
2147       // corresponding dimension on the memory region (stored in 'offset').
2148       auto map = AffineMap::get(
2149           cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
2150       memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
2151     }
2152     // The fast buffer is copied into at location zero; addressing is relative.
2153     bufIndices.push_back(zeroIndex);
2154 
2155     // Record the offsets since they are needed to remap the memory accesses of
2156     // the original memref further below.
2157     fastBufOffsets.push_back(offset);
2158   }
2159 
2160   // The faster memory space buffer.
2161   Value fastMemRef;
2162 
2163   // Check if a buffer was already created.
2164   bool existingBuf = fastBufferMap.count(memref) > 0;
2165   if (!existingBuf) {
2166     AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
2167     auto fastMemRefType =
2168         MemRefType::get(fastBufferShape, memRefType.getElementType(),
2169                         fastBufferLayout, copyOptions.fastMemorySpace);
2170 
2171     // Create the fast memory space buffer just before the 'affine.for'
2172     // operation.
2173     fastMemRef =
2174         prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult();
2175     // Record it.
2176     fastBufferMap[memref] = fastMemRef;
2177     // fastMemRefType is a constant shaped memref.
2178     *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
2179     LLVM_DEBUG(emitRemarkForBlock(*block)
2180                << "Creating fast buffer of type " << fastMemRefType
2181                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
2182                << " KiB\n");
2183   } else {
2184     // Reuse the one already created.
2185     fastMemRef = fastBufferMap[memref];
2186     *sizeInBytes = 0;
2187   }
2188 
2189   auto numElementsSSA =
2190       top.create<arith::ConstantIndexOp>(loc, numElements.getValue());
2191 
2192   Value dmaStride = nullptr;
2193   Value numEltPerDmaStride = nullptr;
2194   if (copyOptions.generateDma) {
2195     SmallVector<StrideInfo, 4> dmaStrideInfos;
2196     getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
2197 
2198     // TODO: use all stride levels once DmaStartOp is extended for
2199     // multi-level strides.
2200     if (dmaStrideInfos.size() > 1) {
2201       LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
2202       return failure();
2203     }
2204 
2205     if (!dmaStrideInfos.empty()) {
2206       dmaStride =
2207           top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
2208       numEltPerDmaStride = top.create<arith::ConstantIndexOp>(
2209           loc, dmaStrideInfos[0].numEltPerStride);
2210     }
2211   }
2212 
2213   // Record the last operation where we want the memref replacement to end. We
2214   // later do the memref replacement only in [begin, postDomFilter] so
2215   // that the original memref's used in the data movement code themselves don't
2216   // get replaced.
2217   auto postDomFilter = std::prev(end);
2218 
2219   // Create fully composed affine maps for each memref.
2220   auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
2221   fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
2222   auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
2223   fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
2224 
2225   if (!copyOptions.generateDma) {
2226     // Point-wise copy generation.
2227     auto copyNest =
2228         generatePointWiseCopy(loc, memref, fastMemRef, lbMaps,
2229                               /*lbOperands=*/regionSymbols, ubMaps,
2230                               /*ubOperands=*/regionSymbols, fastBufOffsets,
2231                               /*isCopyOut=*/region.isWrite(), b);
2232 
2233     // Record this so that we can skip it from yet another copy.
2234     copyNests.insert(copyNest);
2235 
2236     // Since new ops are being appended (for copy out's), adjust the end to
2237     // mark end of block range being processed if necessary.
2238     if (region.isWrite() && isCopyOutAtEndOfBlock)
2239       *nEnd = Block::iterator(copyNest.getOperation());
2240   } else {
2241     // DMA generation.
2242     // Create a tag (single element 1-d memref) for the DMA.
2243     auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
2244                                          copyOptions.tagMemorySpace);
2245     auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType);
2246 
2247     SmallVector<Value, 4> tagIndices({zeroIndex});
2248     auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
2249     fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
2250     if (!region.isWrite()) {
2251       // DMA non-blocking read from original buffer to fast buffer.
2252       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
2253                                  fastMemRef, bufAffineMap, bufIndices,
2254                                  tagMemRef, tagAffineMap, tagIndices,
2255                                  numElementsSSA, dmaStride, numEltPerDmaStride);
2256     } else {
2257       // DMA non-blocking write from fast buffer to the original memref.
2258       auto op = b.create<AffineDmaStartOp>(
2259           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
2260           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
2261           dmaStride, numEltPerDmaStride);
2262       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
2263       // end to mark end of block range being processed.
2264       if (isCopyOutAtEndOfBlock)
2265         *nEnd = Block::iterator(op.getOperation());
2266     }
2267 
2268     // Matching DMA wait to block on completion; tag always has a 0 index.
2269     b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
2270                               numElementsSSA);
2271 
2272     // Generate dealloc for the tag.
2273     auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef);
2274     if (*nEnd == end && isCopyOutAtEndOfBlock)
2275       // Since new ops are being appended (for outgoing DMAs), adjust the end to
2276       // mark end of range of the original.
2277       *nEnd = Block::iterator(tagDeallocOp.getOperation());
2278   }
2279 
2280   // Generate dealloc for the buffer.
2281   if (!existingBuf) {
2282     auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef);
2283     // When generating pointwise copies, `nEnd' has to be set to deallocOp on
2284     // the fast buffer (since it marks the new end insertion point).
2285     if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
2286       *nEnd = Block::iterator(bufDeallocOp.getOperation());
2287   }
2288 
2289   // Replace all uses of the old memref with the faster one while remapping
2290   // access indices (subtracting out lower bound offsets for each dimension).
2291   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
2292   // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
2293   // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
2294   // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
2295   // d2, d3 correspond to the original indices (%i, %j).
2296   SmallVector<AffineExpr, 4> remapExprs;
2297   remapExprs.reserve(rank);
2298   for (unsigned i = 0; i < rank; i++) {
2299     // The starting operands of indexRemap will be regionSymbols (the symbols on
2300     // which the memref region is parametric); then those corresponding to
2301     // the memref's original indices follow.
2302     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
2303     remapExprs.push_back(dimExpr - fastBufOffsets[i]);
2304   }
2305   auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs,
2306                                    b.getContext());
2307 
2308   // Record the begin since it may be invalidated by memref replacement.
2309   Block::iterator prevOfBegin;
2310   bool isBeginAtStartOfBlock = (begin == block->begin());
2311   if (!isBeginAtStartOfBlock)
2312     prevOfBegin = std::prev(begin);
2313 
2314   // *Only* those uses within the range [begin, end) of 'block' are replaced.
2315   (void)replaceAllMemRefUsesWith(memref, fastMemRef,
2316                                  /*extraIndices=*/{}, indexRemap,
2317                                  /*extraOperands=*/regionSymbols,
2318                                  /*symbolOperands=*/{},
2319                                  /*domOpFilter=*/&*begin,
2320                                  /*postDomOpFilter=*/&*postDomFilter);
2321 
2322   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
2323 
2324   return success();
2325 }
2326 
2327 /// Construct the memref region to just include the entire memref. Returns false
2328 /// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
2329 /// enclosing loop IVs of `op` (starting from the outermost) that the region
2330 /// is parametric on.
2331 static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
2332                                   MemRefRegion *region) {
2333   unsigned rank;
2334   if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
2335     rank = loadOp.getMemRefType().getRank();
2336     region->memref = loadOp.getMemRef();
2337     region->setWrite(false);
2338   } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
2339     rank = storeOp.getMemRefType().getRank();
2340     region->memref = storeOp.getMemRef();
2341     region->setWrite(true);
2342   } else {
2343     assert(false && "expected load or store op");
2344     return false;
2345   }
2346   auto memRefType = region->memref.getType().cast<MemRefType>();
2347   if (!memRefType.hasStaticShape())
2348     return false;
2349 
2350   auto *regionCst = region->getConstraints();
2351 
2352   // Just get the first numSymbols IVs, which the memref region is parametric
2353   // on.
2354   SmallVector<AffineForOp, 4> ivs;
2355   getLoopIVs(*op, &ivs);
2356   ivs.resize(numParamLoopIVs);
2357   SmallVector<Value, 4> symbols;
2358   extractForInductionVars(ivs, &symbols);
2359   regionCst->reset(rank, numParamLoopIVs, 0);
2360   regionCst->setValues(rank, rank + numParamLoopIVs, symbols);
2361 
2362   // Memref dim sizes provide the bounds.
2363   for (unsigned d = 0; d < rank; d++) {
2364     auto dimSize = memRefType.getDimSize(d);
2365     assert(dimSize > 0 && "filtered dynamic shapes above");
2366     regionCst->addBound(FlatAffineConstraints::LB, d, 0);
2367     regionCst->addBound(FlatAffineConstraints::UB, d, dimSize - 1);
2368   }
2369   return true;
2370 }
2371 
2372 LogicalResult mlir::affineDataCopyGenerate(Block::iterator begin,
2373                                            Block::iterator end,
2374                                            const AffineCopyOptions &copyOptions,
2375                                            Optional<Value> filterMemRef,
2376                                            DenseSet<Operation *> &copyNests) {
2377   if (begin == end)
2378     return success();
2379 
2380   assert(begin->getBlock() == std::prev(end)->getBlock() &&
2381          "Inconsistent block begin/end args");
2382   assert(end != end->getBlock()->end() && "end can't be the block terminator");
2383 
2384   Block *block = begin->getBlock();
2385 
2386   // Copies will be generated for this depth, i.e., symbolic in all loops
2387   // surrounding the this block range.
2388   unsigned copyDepth = getNestingDepth(&*begin);
2389 
2390   LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
2391                           << "\n");
2392   LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
2393   LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
2394 
2395   // List of memory regions to copy for. We need a map vector to have a
2396   // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
2397   // since the alloc's for example are identical except for the SSA id.
2398   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
2399   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
2400 
2401   // Map from original memref's to the fast buffers that their accesses are
2402   // replaced with.
2403   DenseMap<Value, Value> fastBufferMap;
2404 
2405   // To check for errors when walking the block.
2406   bool error = false;
2407 
2408   // Walk this range of operations  to gather all memory regions.
2409   block->walk(begin, end, [&](Operation *opInst) {
2410     // Gather regions to allocate to buffers in faster memory space.
2411     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
2412       if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) ||
2413           (loadOp.getMemRefType().getMemorySpaceAsInt() !=
2414            copyOptions.slowMemorySpace))
2415         return;
2416     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
2417       if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) ||
2418           storeOp.getMemRefType().getMemorySpaceAsInt() !=
2419               copyOptions.slowMemorySpace)
2420         return;
2421     } else {
2422       // Neither load nor a store op.
2423       return;
2424     }
2425 
2426     // Compute the MemRefRegion accessed.
2427     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2428     if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
2429                                /*addMemRefDimBounds=*/false))) {
2430       LLVM_DEBUG(llvm::dbgs()
2431                  << "Error obtaining memory region: semi-affine maps?\n");
2432       LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
2433       if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2434         LLVM_DEBUG(
2435             opInst->emitError("non-constant memref sizes not yet supported"));
2436         error = true;
2437         return;
2438       }
2439     }
2440 
2441     // Each memref has a single buffer associated with it irrespective of how
2442     // many load's and store's happen on it.
2443     // TODO: in the future, when regions don't intersect and satisfy
2444     // other properties (based on load/store regions), we could consider
2445     // multiple buffers per memref.
2446 
2447     // Add to the appropriate region if it's not already in it, or take a
2448     // bounding box union with the existing one if it's already in there.
2449     // Note that a memref may have both read and write regions - so update the
2450     // region in the other list if one exists (write in case of read and vice
2451     // versa) since there is a single bounding box for a memref across all reads
2452     // and writes that happen on it.
2453 
2454     // Attempts to update; returns true if 'region' exists in targetRegions.
2455     auto updateRegion =
2456         [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2457                 &targetRegions) {
2458           const auto *const it = targetRegions.find(region->memref);
2459           if (it == targetRegions.end())
2460             return false;
2461 
2462           // Perform a union with the existing region.
2463           if (failed(it->second->unionBoundingBox(*region))) {
2464             LLVM_DEBUG(llvm::dbgs()
2465                        << "Memory region bounding box failed; "
2466                           "over-approximating to the entire memref\n");
2467             // If the union fails, we will overapproximate.
2468             if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2469               LLVM_DEBUG(opInst->emitError(
2470                   "non-constant memref sizes not yet supported"));
2471               error = true;
2472               return true;
2473             }
2474             it->second->getConstraints()->clearAndCopyFrom(
2475                 *region->getConstraints());
2476           } else {
2477             // Union was computed and stored in 'it->second': copy to 'region'.
2478             region->getConstraints()->clearAndCopyFrom(
2479                 *it->second->getConstraints());
2480           }
2481           return true;
2482         };
2483 
2484     bool existsInRead = updateRegion(readRegions);
2485     if (error)
2486       return;
2487     bool existsInWrite = updateRegion(writeRegions);
2488     if (error)
2489       return;
2490 
2491     // Finally add it to the region list.
2492     if (region->isWrite() && !existsInWrite) {
2493       writeRegions[region->memref] = std::move(region);
2494     } else if (!region->isWrite() && !existsInRead) {
2495       readRegions[region->memref] = std::move(region);
2496     }
2497   });
2498 
2499   if (error) {
2500     LLVM_DEBUG(begin->emitError(
2501         "copy generation failed for one or more memref's in this block\n"));
2502     return failure();
2503   }
2504 
2505   uint64_t totalCopyBuffersSizeInBytes = 0;
2506   bool ret = true;
2507   auto processRegions =
2508       [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
2509               &regions) {
2510         for (const auto &regionEntry : regions) {
2511           // For each region, hoist copy in/out past all hoistable
2512           // 'affine.for's.
2513           Block::iterator copyInPlacementStart, copyOutPlacementStart;
2514           Block *copyPlacementBlock;
2515           findHighestBlockForPlacement(
2516               *regionEntry.second, *block, begin, end, &copyPlacementBlock,
2517               &copyInPlacementStart, &copyOutPlacementStart);
2518 
2519           uint64_t sizeInBytes;
2520           Block::iterator nBegin, nEnd;
2521           LogicalResult iRet = generateCopy(
2522               *regionEntry.second, block, begin, end, copyPlacementBlock,
2523               copyInPlacementStart, copyOutPlacementStart, copyOptions,
2524               fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
2525           if (succeeded(iRet)) {
2526             // begin/end could have been invalidated, and need update.
2527             begin = nBegin;
2528             end = nEnd;
2529             totalCopyBuffersSizeInBytes += sizeInBytes;
2530           }
2531           ret = ret & succeeded(iRet);
2532         }
2533       };
2534   processRegions(readRegions);
2535   processRegions(writeRegions);
2536 
2537   if (!ret) {
2538     LLVM_DEBUG(begin->emitError(
2539         "copy generation failed for one or more memref's in this block\n"));
2540     return failure();
2541   }
2542 
2543   // For a range of operations, a note will be emitted at the caller.
2544   AffineForOp forOp;
2545   if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
2546     LLVM_DEBUG(forOp.emitRemark()
2547                << llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024)
2548                << " KiB of copy buffers in fast memory space for this block\n");
2549   }
2550 
2551   if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
2552     StringRef str = "Total size of all copy buffers' for this block "
2553                     "exceeds fast memory capacity\n";
2554     block->getParentOp()->emitWarning(str);
2555   }
2556 
2557   return success();
2558 }
2559 
2560 // A convenience version of affineDataCopyGenerate for all ops in the body of
2561 // an AffineForOp.
2562 LogicalResult mlir::affineDataCopyGenerate(AffineForOp forOp,
2563                                            const AffineCopyOptions &copyOptions,
2564                                            Optional<Value> filterMemRef,
2565                                            DenseSet<Operation *> &copyNests) {
2566   return affineDataCopyGenerate(forOp.getBody()->begin(),
2567                                 std::prev(forOp.getBody()->end()), copyOptions,
2568                                 filterMemRef, copyNests);
2569 }
2570 
2571 LogicalResult mlir::generateCopyForMemRegion(
2572     const MemRefRegion &memrefRegion, Operation *analyzedOp,
2573     const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
2574   Block *block = analyzedOp->getBlock();
2575   auto begin = analyzedOp->getIterator();
2576   auto end = std::next(begin);
2577   DenseMap<Value, Value> fastBufferMap;
2578   DenseSet<Operation *> copyNests;
2579 
2580   auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
2581                           copyOptions, fastBufferMap, copyNests,
2582                           &result.sizeInBytes, &begin, &end);
2583   if (failed(err))
2584     return err;
2585 
2586   const auto &en = fastBufferMap.find(memrefRegion.memref);
2587   // In some cases (empty loops), no copy generation would have happened.
2588   if (en == fastBufferMap.end())
2589     return failure();
2590   result.alloc = en->second.getDefiningOp();
2591   assert(result.alloc && "fast buffer expected to be locally allocated");
2592   assert(copyNests.size() <= 1 && "At most one copy nest is expected.");
2593   result.copyNest = copyNests.empty() ? nullptr : *copyNests.begin();
2594   return success();
2595 }
2596 
2597 /// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
2598 static void
2599 gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
2600                    std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2601   // Add a new empty level to output if it doesn't exist level already.
2602   assert(currLoopDepth <= depthToLoops.size() && "Unexpected currLoopDepth");
2603   if (currLoopDepth == depthToLoops.size())
2604     depthToLoops.emplace_back();
2605 
2606   for (auto &op : *block) {
2607     if (auto forOp = dyn_cast<AffineForOp>(op)) {
2608       depthToLoops[currLoopDepth].push_back(forOp);
2609       gatherLoopsInBlock(forOp.getBody(), currLoopDepth + 1, depthToLoops);
2610     }
2611   }
2612 }
2613 
2614 /// Gathers all AffineForOps in 'builtin.func' grouped by loop depth.
2615 void mlir::gatherLoops(FuncOp func,
2616                        std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
2617   for (auto &block : func)
2618     gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);
2619 
2620   // Remove last loop level from output since it's empty.
2621   if (!depthToLoops.empty()) {
2622     assert(depthToLoops.back().empty() && "Last loop level is not empty?");
2623     depthToLoops.pop_back();
2624   }
2625 }
2626 
2627 // TODO: if necessary, this can be extended to also compose in any
2628 // affine.applys, fold to constant if all result dimensions of the map are
2629 // constant (canonicalizeMapAndOperands below already does this for single
2630 // result bound maps), and use simplifyMap to perform algebraic simplification.
2631 AffineForOp mlir::createCanonicalizedAffineForOp(
2632     OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
2633     ValueRange ubOperands, AffineMap ubMap, int64_t step) {
2634   SmallVector<Value, 4> lowerOperands(lbOperands);
2635   SmallVector<Value, 4> upperOperands(ubOperands);
2636 
2637   fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
2638   canonicalizeMapAndOperands(&lbMap, &lowerOperands);
2639   lbMap = removeDuplicateExprs(lbMap);
2640   fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
2641   canonicalizeMapAndOperands(&ubMap, &upperOperands);
2642   ubMap = removeDuplicateExprs(ubMap);
2643 
2644   return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
2645                                step);
2646 }
2647 
2648 /// Creates an AffineIfOp that encodes the conditional to choose between
2649 /// the constant trip count version and an unknown trip count version of this
2650 /// nest of loops. This is used to separate partial and full tiles if `loops`
2651 /// has the intra-tile loops. The affine.if op is inserted at the builder
2652 /// insertion point of `b`.
2653 static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
2654                                             OpBuilder b) {
2655   if (loops.empty())
2656     return nullptr;
2657 
2658   auto *context = loops[0].getContext();
2659 
2660   FlatAffineValueConstraints cst;
2661   SmallVector<Operation *, 8> ops;
2662   ops.reserve(loops.size());
2663   for (AffineForOp forOp : loops)
2664     ops.push_back(forOp);
2665   (void)getIndexSet(ops, &cst);
2666 
2667   // Remove constraints that are independent of these loop IVs.
2668   cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());
2669 
2670   // Construct the constraint set representing the guard for full tiles. The
2671   // lower bound (and upper bound) corresponding to the full tile should be
2672   // larger (and resp. smaller) than any other lower (or upper bound).
2673   SmallVector<int64_t, 8> fullTileLb, fullTileUb;
2674   for (auto loop : loops) {
2675     (void)loop;
2676     // TODO: Non-unit stride is not an issue to generalize to.
2677     assert(loop.getStep() == 1 && "point loop step expected to be one");
2678     // Mark everything symbols for the purpose of finding a constant diff pair.
2679     cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolIds() -
2680                                1);
2681     unsigned fullTileLbPos, fullTileUbPos;
2682     if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
2683                                        /*boundFloorDivisor=*/nullptr,
2684                                        /*ub=*/nullptr, &fullTileLbPos,
2685                                        &fullTileUbPos)) {
2686       LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
2687       return nullptr;
2688     }
2689 
2690     SmallVector<unsigned, 4> lbIndices, ubIndices;
2691     cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);
2692 
2693     auto fLb = cst.getInequality(fullTileLbPos);
2694     auto fUb = cst.getInequality(fullTileUbPos);
2695     fullTileLb.assign(fLb.begin(), fLb.end());
2696     fullTileUb.assign(fUb.begin(), fUb.end());
2697 
2698     // Full tile lower bound should be >= than any other lower bound.
2699     for (auto lbIndex : lbIndices)
2700       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2701         cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);
2702 
2703     // Full tile upper bound should be <= any other upper bound.
2704     for (auto ubIndex : ubIndices)
2705       for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
2706         cst.atIneq(ubIndex, i) -= fullTileUb[i];
2707 
2708     cst.removeId(0);
2709   }
2710 
2711   // The previous step leads to all zeros for the full tile lb and ub position
2712   // itself; remove those and any other duplicates / trivial redundancies.
2713   cst.removeTrivialRedundancy();
2714 
2715   // Turn everything into dims conservatively since we earlier turned all
2716   // trailing ids past point loop IV into symbols. Some of these could be outer
2717   // loop IVs; we'll canonicalize anyway.
2718   cst.setDimSymbolSeparation(0);
2719 
2720   IntegerSet ifCondSet = cst.getAsIntegerSet(context);
2721   // ifCondSet can be null if cst was empty -- this can happen if all loops
2722   // in the nest have constant trip counts.
2723   if (!ifCondSet)
2724     return nullptr;
2725 
2726   SmallVector<Value, 4> setOperands;
2727   cst.getValues(0, cst.getNumDimAndSymbolIds(), &setOperands);
2728   canonicalizeSetAndOperands(&ifCondSet, &setOperands);
2729   return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
2730                               /*withElseRegion=*/true);
2731 }
2732 
2733 /// Create the full tile loop nest (along with its body).
2734 static LogicalResult
2735 createFullTiles(MutableArrayRef<AffineForOp> inputNest,
2736                 SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
2737   fullTileLoops.reserve(inputNest.size());
2738 
2739   // For each loop in the original nest identify a lower/upper bound pair such
2740   // that their difference is a constant.
2741   FlatAffineValueConstraints cst;
2742   for (auto loop : inputNest) {
2743     // TODO: straightforward to generalize to a non-unit stride.
2744     if (loop.getStep() != 1) {
2745       LLVM_DEBUG(llvm::dbgs()
2746                  << "[tile separation] non-unit stride not implemented\n");
2747       return failure();
2748     }
2749     SmallVector<Operation *, 1> loopOp{loop.getOperation()};
2750     (void)getIndexSet(loopOp, &cst);
2751     // We will mark everything other than this loop IV as symbol for getting a
2752     // pair of <lb, ub> with a constant difference.
2753     cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
2754     unsigned lbPos, ubPos;
2755     if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
2756                                        /*lbDivisor=*/nullptr, /*ub=*/nullptr,
2757                                        &lbPos, &ubPos) ||
2758         lbPos == ubPos) {
2759       LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
2760                                  "equalities not yet handled\n");
2761       return failure();
2762     }
2763 
2764     // Set all identifiers as dimensions uniformly since some of those marked as
2765     // symbols above could be outer loop IVs (corresponding tile space IVs).
2766     cst.setDimSymbolSeparation(/*newSymbolCount=*/0);
2767 
2768     AffineValueMap lbVmap, ubVmap;
2769     cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
2770     cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
2771     AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
2772         b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
2773         ubVmap.getOperands(), ubVmap.getAffineMap());
2774     b = OpBuilder::atBlockTerminator(fullTileLoop.getBody());
2775     fullTileLoops.push_back(fullTileLoop);
2776   }
2777 
2778   // Add the body for the full tile loop nest.
2779   BlockAndValueMapping operandMap;
2780   for (const auto &loopEn : llvm::enumerate(inputNest))
2781     operandMap.map(loopEn.value().getInductionVar(),
2782                    fullTileLoops[loopEn.index()].getInductionVar());
2783   b = OpBuilder::atBlockTerminator(fullTileLoops.back().getBody());
2784   for (auto &op : inputNest.back().getBody()->without_terminator())
2785     b.clone(op, operandMap);
2786   return success();
2787 }
2788 
2789 LogicalResult
2790 mlir::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
2791                         SmallVectorImpl<AffineForOp> *fullTileNest) {
2792   if (inputNest.empty())
2793     return success();
2794 
2795   auto firstLoop = inputNest[0];
2796 
2797   // Each successive for op has to be nested in the other.
2798   auto prevLoop = firstLoop;
2799   for (auto loop : inputNest.drop_front(1)) {
2800     assert(loop->getParentOp() == prevLoop && "input not contiguously nested");
2801     prevLoop = loop;
2802   }
2803 
2804   // Create the full tile loop nest.
2805   SmallVector<AffineForOp, 4> fullTileLoops;
2806   OpBuilder b(firstLoop);
2807   if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
2808     if (!fullTileLoops.empty())
2809       fullTileLoops.front().erase();
2810     return failure();
2811   }
2812 
2813   // Create and insert the version select right before the root of the nest.
2814   b = OpBuilder(firstLoop);
2815   AffineIfOp ifOp = createSeparationCondition(inputNest, b);
2816   if (!ifOp) {
2817     fullTileLoops.front().erase();
2818     LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating "
2819                                "separation condition\n");
2820     return failure();
2821   }
2822 
2823   // Move the full tile into the then block.
2824   Block *thenBlock = ifOp.getThenBlock();
2825   AffineForOp outermostFullTileLoop = fullTileLoops[0];
2826   thenBlock->getOperations().splice(
2827       std::prev(thenBlock->end()),
2828       outermostFullTileLoop->getBlock()->getOperations(),
2829       Block::iterator(outermostFullTileLoop));
2830 
2831   // Move the partial tile into the else block. The partial tile is the same as
2832   // the original loop nest.
2833   Block *elseBlock = ifOp.getElseBlock();
2834   elseBlock->getOperations().splice(std::prev(elseBlock->end()),
2835                                     firstLoop->getBlock()->getOperations(),
2836                                     Block::iterator(firstLoop));
2837 
2838   if (fullTileNest)
2839     *fullTileNest = std::move(fullTileLoops);
2840 
2841   return success();
2842 }
2843