1 //===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements miscellaneous loop transformation routines.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopUtils.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/Utils.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
20 #include "mlir/Dialect/Affine/Utils.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/SCF/SCF.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/IR/IntegerSet.h"
26 #include "mlir/Support/MathExtras.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/RegionUtils.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/SmallPtrSet.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 #define DEBUG_TYPE "LoopUtils"
35 
36 using namespace mlir;
37 using llvm::SmallMapVector;
38 
39 namespace {
// Helper structure to pass around and return sets of loop parameters without
// risking confusion about their order.
42 struct LoopParams {
43   Value lowerBound;
44   Value upperBound;
45   Value step;
46 };
47 } // namespace
48 
/// Computes the cleanup loop lower bound of the loop being unrolled with
/// the specified unroll factor; this bound will also be the upper bound of
/// the main part of the unrolled loop. Computes the bound as an AffineMap
/// with its operands, or a null map when the trip count can't be expressed
/// as an affine expression.
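///
/// As an illustrative example (not from the original source): unrolling
///   affine.for %i = 0 to 13
/// by a factor of 4 gives a trip count of 13, so the cleanup lower bound
/// computed here is 0 + (13 - 13 mod 4) = 12; the main unrolled loop then
/// covers [0, 12) and the cleanup loop covers [12, 13).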
54 static void
55 getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
56                          AffineMap &cleanupLbMap,
57                          SmallVectorImpl<Value> &cleanupLbOperands) {
58   AffineMap tripCountMap;
59   SmallVector<Value, 4> tripCountOperands;
60   getTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
61   // Trip count can't be computed.
62   if (!tripCountMap) {
63     cleanupLbMap = AffineMap();
64     return;
65   }
66 
67   OpBuilder b(forOp);
68   auto lbMap = forOp.getLowerBoundMap();
69   auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
70                                     forOp.getLowerBoundOperands());
71 
  // For each upper bound expression, get the range.
  // E.g., for `affine.for %i = lb to min (ub1, ub2)`,
  // where the trip count exprs yield (tr1, tr2), we create the affine.apply's
  // lb + tr1 - tr1 % ufactor and lb + tr2 - tr2 % ufactor; the results of all
  // these affine.apply's make up the cleanup loop lower bound.
77   SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
78   SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
79   int64_t step = forOp.getStep();
80   for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
81     auto tripCountExpr = tripCountMap.getResult(i);
82     bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
83     auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
84                                   tripCountMap.getNumSymbols(), bumpExprs[i]);
85     bumpValues[i] =
86         b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
87   }
88 
89   SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
90   for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
91     newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
92 
93   cleanupLbOperands.clear();
94   cleanupLbOperands.push_back(lb);
95   cleanupLbOperands.append(bumpValues.begin(), bumpValues.end());
96   cleanupLbMap = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs,
97                                 b.getContext());
98   // Simplify the cleanupLbMap + cleanupLbOperands.
99   fullyComposeAffineMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
100   cleanupLbMap = simplifyAffineMap(cleanupLbMap);
101   canonicalizeMapAndOperands(&cleanupLbMap, &cleanupLbOperands);
102   // Remove any affine.apply's that became dead from the simplification above.
103   for (auto v : bumpValues)
104     if (v.use_empty())
105       v.getDefiningOp()->erase();
106 
107   if (lb.use_empty())
108     lb.erase();
109 }
110 
/// Helper to replace uses of loop-carried values (iter_args) and loop
/// yield values while promoting single-iteration affine.for ops.
113 static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
114   // Replace uses of iter arguments with iter operands (initial values).
115   auto iterOperands = forOp.getIterOperands();
116   auto iterArgs = forOp.getRegionIterArgs();
117   for (auto e : llvm::zip(iterOperands, iterArgs))
118     std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
119 
120   // Replace uses of loop results with the values yielded by the loop.
121   auto outerResults = forOp.getResults();
122   auto innerResults = forOp.getBody()->getTerminator()->getOperands();
123   for (auto e : llvm::zip(outerResults, innerResults))
124     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
125 }
126 
/// Promotes the loop body of a forOp to its containing block if the forOp
/// is known to have a single iteration.
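///
/// For example (illustrative): since
///   affine.for %i = 5 to 6 { "foo"(%i) : (index) -> () }
/// has a constant trip count of one, "foo" is moved into the enclosing
/// block with %i replaced by an arith.constant of 5, and the loop is erased.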
129 // TODO: extend this for arbitrary affine bounds.
130 LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
131   Optional<uint64_t> tripCount = getConstantTripCount(forOp);
132   if (!tripCount || tripCount.getValue() != 1)
133     return failure();
134 
135   if (forOp.getLowerBoundMap().getNumResults() != 1)
136     return failure();
137 
  // Replace all uses of the IV with its single iteration value.
139   auto iv = forOp.getInductionVar();
140   auto *parentBlock = forOp->getBlock();
141   if (!iv.use_empty()) {
142     if (forOp.hasConstantLowerBound()) {
143       OpBuilder topBuilder(forOp->getParentOfType<FuncOp>().getBody());
144       auto constOp = topBuilder.create<arith::ConstantIndexOp>(
145           forOp.getLoc(), forOp.getConstantLowerBound());
146       iv.replaceAllUsesWith(constOp);
147     } else {
148       auto lbOperands = forOp.getLowerBoundOperands();
149       auto lbMap = forOp.getLowerBoundMap();
150       OpBuilder builder(forOp);
151       if (lbMap == builder.getDimIdentityMap()) {
152         // No need of generating an affine.apply.
153         iv.replaceAllUsesWith(lbOperands[0]);
154       } else {
155         auto affineApplyOp =
156             builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
157         iv.replaceAllUsesWith(affineApplyOp);
158       }
159     }
160   }
161 
162   replaceIterArgsAndYieldResults(forOp);
163 
164   // Move the loop body operations, except for its terminator, to the loop's
165   // containing block.
166   forOp.getBody()->back().erase();
167   parentBlock->getOperations().splice(Block::iterator(forOp),
168                                       forOp.getBody()->getOperations());
169   forOp.erase();
170   return success();
171 }
172 
173 /// Generates an affine.for op with the specified lower and upper bounds
174 /// while generating the right IV remappings to realize shifts for operations in
175 /// its body. The operations that go into the loop body are specified in
176 /// opGroupQueue starting from the specified offset, and in that order. The
177 /// first element of the pair specifies the shift applied to that group of
178 /// operations; the shift is multiplied by the loop step before being applied.
179 /// Returns nullptr if the generated loop simplifies to a single iteration one.
180 static AffineForOp generateShiftedLoop(
181     AffineMap lbMap, AffineMap ubMap,
182     const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> &opGroupQueue,
183     unsigned offset, AffineForOp srcForOp, OpBuilder b) {
184   auto lbOperands = srcForOp.getLowerBoundOperands();
185   auto ubOperands = srcForOp.getUpperBoundOperands();
186 
187   assert(lbMap.getNumInputs() == lbOperands.size());
188   assert(ubMap.getNumInputs() == ubOperands.size());
189 
190   auto loopChunk = b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap,
191                                          ubOperands, ubMap, srcForOp.getStep());
192   auto loopChunkIV = loopChunk.getInductionVar();
193   auto srcIV = srcForOp.getInductionVar();
194 
195   BlockAndValueMapping operandMap;
196 
197   auto bodyBuilder = OpBuilder::atBlockTerminator(loopChunk.getBody());
198   for (auto it = opGroupQueue.begin() + offset, e = opGroupQueue.end(); it != e;
199        ++it) {
200     uint64_t shift = it->first;
201     auto ops = it->second;
    // All 'same shift' operations get added with their operands being
    // remapped to results of cloned operations, and their IV uses remapped.
    // Generate the remapping if the shift is not zero: remappedIV = newIV -
    // shift.
206     if (!srcIV.use_empty() && shift != 0) {
207       auto ivRemap = bodyBuilder.create<AffineApplyOp>(
208           srcForOp.getLoc(),
209           bodyBuilder.getSingleDimShiftAffineMap(
210               -static_cast<int64_t>(srcForOp.getStep() * shift)),
211           loopChunkIV);
212       operandMap.map(srcIV, ivRemap);
213     } else {
214       operandMap.map(srcIV, loopChunkIV);
215     }
216     for (auto *op : ops)
217       bodyBuilder.clone(*op, operandMap);
218   };
219   if (succeeded(promoteIfSingleIteration(loopChunk)))
220     return AffineForOp();
221   return loopChunk;
222 }
223 
// The skewing of operations with respect to one another can be used for
// example to allow overlap of asynchronous operations (such as DMA
// communication) with computation, or just relative shifting of operations
// for better register reuse, locality, or parallelism. As such, the shifts
// are typically expected to be at most of the order of the number of
// operations. This method should not be used as a substitute for loop
// distribution/fission. This method uses an algorithm that runs in time
// linear in the number of operations in the body of the for loop (using the
// 'sweep line' paradigm). This method asserts preservation of SSA dominance;
// a check for that, as well as for memory-based dependence preservation,
// rests with the users of this method.
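//
// As a hypothetical example, skewing the two-op body of
//   affine.for %i = 0 to 8 { "produce"(%i, %buf); "consume"(%i, %buf) }
// (the ops communicating through memory) with shifts = [0, 1] yields a
// one-iteration prologue over [0, 1) running only "produce", a steady-state
// loop over [1, 8) running "produce" for %i and "consume" for %i - 1, and a
// one-iteration epilogue over [8, 9) running the remaining "consume".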
235 LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
236                                         ArrayRef<uint64_t> shifts,
237                                         bool unrollPrologueEpilogue) {
238   assert(forOp.getBody()->getOperations().size() == shifts.size() &&
239          "too few/many shifts");
240   if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
241     return success();
242 
243   // If the trip counts aren't constant, we would need versioning and
244   // conditional guards (or context information to prevent such versioning). The
245   // better way to pipeline for such loops is to first tile them and extract
246   // constant trip count "full tiles" before applying this.
247   auto mayBeConstTripCount = getConstantTripCount(forOp);
248   if (!mayBeConstTripCount.hasValue()) {
249     LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
250     return success();
251   }
252   uint64_t tripCount = mayBeConstTripCount.getValue();
253 
254   assert(isOpwiseShiftValid(forOp, shifts) &&
255          "shifts will lead to an invalid transformation\n");
256 
257   int64_t step = forOp.getStep();
258 
259   unsigned numChildOps = shifts.size();
260 
261   // Do a linear time (counting) sort for the shifts.
262   uint64_t maxShift = *std::max_element(shifts.begin(), shifts.end());
263   if (maxShift >= numChildOps) {
264     // Large shifts are not the typical use case.
265     forOp.emitWarning("not shifting because shifts are unrealistically large");
266     return success();
267   }
268 
269   // An array of operation groups sorted by shift amount; each group has all
270   // operations with the same shift in the order in which they appear in the
271   // body of the 'affine.for' op.
272   std::vector<std::vector<Operation *>> sortedOpGroups(maxShift + 1);
273   unsigned pos = 0;
274   for (auto &op : forOp.getBody()->without_terminator()) {
275     auto shift = shifts[pos++];
276     sortedOpGroups[shift].push_back(&op);
277   }
278 
279   // Unless the shifts have a specific pattern (which actually would be the
280   // common use case), prologue and epilogue are not meaningfully defined.
281   // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
282   // loop generated as the prologue and the last as epilogue and unroll these
283   // fully.
284   AffineForOp prologue, epilogue;
285 
286   // Do a sweep over the sorted shifts while storing open groups in a
287   // vector, and generating loop portions as necessary during the sweep. A block
288   // of operations is paired with its shift.
289   std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> opGroupQueue;
290 
291   auto origLbMap = forOp.getLowerBoundMap();
292   uint64_t lbShift = 0;
293   OpBuilder b(forOp);
294   for (uint64_t d = 0, e = sortedOpGroups.size(); d < e; ++d) {
295     // If nothing is shifted by d, continue.
296     if (sortedOpGroups[d].empty())
297       continue;
298     if (!opGroupQueue.empty()) {
299       assert(d > 0 &&
300              "Queue expected to be empty when the first block is found");
301       // The interval for which the loop needs to be generated here is:
302       // [lbShift, min(lbShift + tripCount, d)) and the body of the
303       // loop needs to have all operations in opQueue in that order.
304       AffineForOp res;
305       if (lbShift + tripCount * step < d * step) {
306         res = generateShiftedLoop(
307             b.getShiftedAffineMap(origLbMap, lbShift),
308             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
309             opGroupQueue, /*offset=*/0, forOp, b);
310         // Entire loop for the queued op groups generated, empty it.
311         opGroupQueue.clear();
312         lbShift += tripCount * step;
313       } else {
314         res = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
315                                   b.getShiftedAffineMap(origLbMap, d),
316                                   opGroupQueue, /*offset=*/0, forOp, b);
317         lbShift = d * step;
318       }
319 
320       if (res) {
321         // Simplify/canonicalize the affine.for.
322         RewritePatternSet patterns(res.getContext());
323         AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
324         bool erased;
325         (void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
326 
327         if (!erased && !prologue)
328           prologue = res;
329         if (!erased)
330           epilogue = res;
331       }
332     } else {
333       // Start of first interval.
334       lbShift = d * step;
335     }
336     // Augment the list of operations that get into the current open interval.
337     opGroupQueue.emplace_back(d, sortedOpGroups[d]);
338   }
339 
  // The operation groups left in the queue now need to be processed (FIFO)
  // and their loops completed.
342   for (unsigned i = 0, e = opGroupQueue.size(); i < e; ++i) {
343     uint64_t ubShift = (opGroupQueue[i].first + tripCount) * step;
344     epilogue = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
345                                    b.getShiftedAffineMap(origLbMap, ubShift),
346                                    opGroupQueue, /*offset=*/i, forOp, b);
347     lbShift = ubShift;
348     if (!prologue)
349       prologue = epilogue;
350   }
351 
352   // Erase the original for op.
353   forOp.erase();
354 
355   if (unrollPrologueEpilogue && prologue)
356     (void)loopUnrollFull(prologue);
  if (unrollPrologueEpilogue && epilogue && epilogue != prologue)
358     (void)loopUnrollFull(epilogue);
359 
360   return success();
361 }
362 
363 /// Checks the legality of tiling of a hyper-rectangular loop nest by simply
364 /// checking if there is a 'negative' dependence in the memrefs present in
365 /// the loop nest. If yes then tiling is invalid.
366 static bool
367 checkTilingLegalityImpl(MutableArrayRef<mlir::AffineForOp> origLoops) {
368   assert(!origLoops.empty() && "no original loops provided");
369 
370   // We first find out all dependences we intend to check.
371   SmallVector<Operation *, 8> loadAndStoreOps;
372   origLoops[0]->walk([&](Operation *op) {
373     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
374       loadAndStoreOps.push_back(op);
375   });
376 
377   unsigned numOps = loadAndStoreOps.size();
378   unsigned numLoops = origLoops.size();
379   FlatAffineValueConstraints dependenceConstraints;
380   for (unsigned d = 1; d <= numLoops + 1; ++d) {
381     for (unsigned i = 0; i < numOps; ++i) {
382       Operation *srcOp = loadAndStoreOps[i];
383       MemRefAccess srcAccess(srcOp);
384       for (unsigned j = 0; j < numOps; ++j) {
385         Operation *dstOp = loadAndStoreOps[j];
386         MemRefAccess dstAccess(dstOp);
387 
388         SmallVector<DependenceComponent, 2> depComps;
389         dependenceConstraints.reset();
390         DependenceResult result = checkMemrefAccessDependence(
391             srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
392 
393         // Skip if there is no dependence in this case.
394         if (!hasDependence(result))
395           continue;
396 
397         // Check whether there is any negative direction vector in the
398         // dependence components found above, which means that dependence is
399         // violated by the default hyper-rect tiling method.
400         LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
401                                    "for dependence at depth: "
402                                 << Twine(d) << " between:\n";);
403         LLVM_DEBUG(srcAccess.opInst->dump(););
404         LLVM_DEBUG(dstAccess.opInst->dump(););
405         for (unsigned k = 0, e = depComps.size(); k < e; k++) {
406           DependenceComponent depComp = depComps[k];
407           if (depComp.lb.hasValue() && depComp.ub.hasValue() &&
408               depComp.lb.getValue() < depComp.ub.getValue() &&
409               depComp.ub.getValue() < 0) {
410             LLVM_DEBUG(llvm::dbgs()
411                        << "Dependence component lb = "
412                        << Twine(depComp.lb.getValue())
413                        << " ub = " << Twine(depComp.ub.getValue())
414                        << " is negative  at depth: " << Twine(d)
415                        << " and thus violates the legality rule.\n");
416             return false;
417           }
418         }
419       }
420     }
421   }
422 
423   return true;
424 }
425 
426 /// Checks whether hyper-rectangular loop tiling of the nest
427 /// represented by `origLoops` is valid. The validity condition is from Irigoin
/// and Triolet, which states that two tiles cannot depend on each other. We
/// simplify this condition to just checking whether there is any negative
/// dependence direction, since we have the prior knowledge that the tiling
431 /// results will be hyper-rectangles, which are scheduled in the
432 /// lexicographically increasing order on the vector of loop indices. This
433 /// function will return failure when any dependence component is negative along
434 /// any of `origLoops`.
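///
/// For instance (illustrative): a 1-d nest carrying only the dependence of
/// iteration %i + 1 on iteration %i (all components non-negative) remains
/// legal to tile, whereas a 2-d nest with a dependence direction vector such
/// as (1, -1), e.g. A[%i + 1][%j - 1] written and A[%i][%j] read, has a
/// negative component and is rejected.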
435 LogicalResult
436 checkTilingLegality(MutableArrayRef<mlir::AffineForOp> origLoops) {
437   return success(checkTilingLegalityImpl(origLoops));
438 }
439 
440 /// Checks whether a loop nest is hyper-rectangular or not.
441 LogicalResult checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) {
442   FlatAffineValueConstraints cst;
443   SmallVector<Operation *, 8> ops(input.begin(), input.end());
444   // 0-d or 1-d is trivially hyper-rectangular.
445   if (input.size() <= 1)
446     return success();
447   if (failed(getIndexSet(ops, &cst))) {
448     LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n");
449     return failure();
450   }
451   if (!cst.isHyperRectangular(0, input.size())) {
452     LLVM_DEBUG(llvm::dbgs()
453                << "Non-hyperrectangular nests not supported for tiling!\n");
454     return failure();
455   }
456   return success();
457 }
458 
459 /// Check if the input nest is supported for tiling and whether tiling would be
460 /// legal or not.
461 template <typename t>
462 LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input,
463                                      ArrayRef<t> tileSizes) {
464   assert(input.size() == tileSizes.size() && "Too few/many tile sizes");
465 
466   if (llvm::any_of(input,
467                    [](AffineForOp op) { return op.getNumResults() > 0; })) {
468     LLVM_DEBUG(llvm::dbgs()
469                << "Cannot tile nest where a loop has yield values\n");
470     return failure();
471   }
472 
473   // Check if the supplied `for` ops are all successively nested.
474   if (!isPerfectlyNested(input)) {
475     LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested");
476     return failure();
477   }
478 
479   if (failed(checkIfHyperRectangular(input)))
480     return failure();
481 
482   // Check if tiling is legal.
483   if (failed(checkTilingLegality(input))) {
484     input[0].emitRemark("tiling code is illegal due to dependences");
485     return failure();
486   }
487 
488   return success();
489 }
490 
491 /// Move the loop body of AffineForOp 'src' from 'src' into the specified
492 /// location in destination's body, ignoring the terminator.
493 static void moveLoopBodyImpl(AffineForOp src, AffineForOp dest,
494                              Block::iterator loc) {
495   auto &ops = src.getBody()->getOperations();
496   dest.getBody()->getOperations().splice(loc, ops, ops.begin(),
497                                          std::prev(ops.end()));
498 }
499 
500 /// Move the loop body of AffineForOp 'src' from 'src' to the start of dest
501 /// body.
502 void moveLoopBody(AffineForOp src, AffineForOp dest) {
503   moveLoopBodyImpl(src, dest, dest.getBody()->begin());
504 }
505 
/// Constructs a tiled loop nest without setting the loop bounds, and moves
/// the body of the original loop nest to the tiled loop nest.
508 void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops,
509                             AffineForOp rootAffineForOp, unsigned width,
510                             MutableArrayRef<AffineForOp> tiledLoops) {
511   Location loc = rootAffineForOp.getLoc();
512 
  // The outermost loop among those created so far; updated as we add more.
514   Operation *topLoop = rootAffineForOp.getOperation();
515   AffineForOp innermostPointLoop;
516 
517   // Add intra-tile (or point) loops.
518   for (unsigned i = 0; i < width; i++) {
519     OpBuilder b(topLoop);
520     // Loop bounds will be set later.
521     AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0);
522     pointLoop.getBody()->getOperations().splice(
523         pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
524         topLoop);
525     tiledLoops[2 * width - 1 - i] = pointLoop;
526     topLoop = pointLoop.getOperation();
527     if (i == 0)
528       innermostPointLoop = pointLoop;
529   }
530 
  // Add tile space loops.
532   for (unsigned i = width; i < 2 * width; i++) {
533     OpBuilder b(topLoop);
534     // Loop bounds will be set later.
535     AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0);
536     tileSpaceLoop.getBody()->getOperations().splice(
537         tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
538         topLoop);
539     tiledLoops[2 * width - i - 1] = tileSpaceLoop;
540     topLoop = tileSpaceLoop.getOperation();
541   }
542 
543   // Move the loop body of the original nest to the new one.
544   moveLoopBody(origLoops.back(), innermostPointLoop);
545 }
546 
547 /// Set lower and upper bounds of intra-tile loops for parametric tiling.
548 //  TODO: Handle non-constant lower bounds.
549 static void setIntraTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
550                                          AffineForOp newInterTileLoop,
551                                          AffineForOp newIntraTileLoop,
552                                          Value tileSize) {
553   // The lower bound for the intra-tile loop is represented by an affine map
554   // as (%i, %t0)->((%i - %origlb) * %t0 + %origlb). Similarly, the upper bound
555   // for the intra-tile loop is represented by an affine map as (%i, %t0)->((%i
556   // - %origlb) * %t0) + (%t0 * %origLoopStep) + %origlb), where %i is loop IV
557   // of the corresponding inter-tile loop, %t0 is the corresponding tiling
558   // parameter, %origlb is lower bound and %origLoopStep is the loop step of the
559   // corresponding inter-tile loop.
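  //
  // As an illustrative instance: for an original loop
  //   affine.for %i = 0 to 1024 step 2
  // (so %origlb = 0 and the step is 2) with tiling parameter %t0, the maps
  // become
  //   lb: (%i, %t0) -> (%i * %t0)
  //   ub: (%i, %t0) -> (%i * %t0 + %t0 * 2)
  // where %i is the IV of the corresponding inter-tile loop.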
560 
561   assert(origLoop.hasConstantLowerBound() &&
562          "expected input loops to have constant lower bound.");
563 
  // Get lower bound of original loop as an affine expression.
  AffineExpr origLowerBoundExpr =
      b.getAffineConstantExpr(origLoop.getConstantLowerBound());
568 
569   // Add dim operands from original lower/upper bound.
570   SmallVector<Value, 4> lbOperands, ubOperands;
571   AffineBound lb = origLoop.getLowerBound();
572   AffineBound ub = origLoop.getUpperBound();
573   lbOperands.reserve(lb.getNumOperands() + 2);
574   ubOperands.reserve(ub.getNumOperands() + 2);
575   AffineMap origLbMap = lb.getMap();
576   AffineMap origUbMap = ub.getMap();
577   for (unsigned j = 0, e = origLbMap.getNumDims(); j < e; ++j)
578     lbOperands.push_back(lb.getOperand(j));
579   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
580     ubOperands.push_back(ub.getOperand(j));
581 
582   // Add a new dim operand in lb/ubOperands corresponding to the origLoop
583   // IV.
584   lbOperands.push_back(newInterTileLoop.getInductionVar());
585   ubOperands.push_back(newInterTileLoop.getInductionVar());
586 
  // Get loop IV as an affine expression for lower/upper bound. Size of
  // lb/ubOperands is guaranteed to be at least one.
589   AffineExpr lbLoopIvExpr = b.getAffineDimExpr(lbOperands.size() - 1);
590   AffineExpr ubLoopIvExpr = b.getAffineDimExpr(ubOperands.size() - 1);
591 
592   // Add symbol operands from original lower/upper bound.
593   for (unsigned j = 0, e = origLbMap.getNumSymbols(); j < e; ++j)
594     lbOperands.push_back(lb.getOperand(origLbMap.getNumDims() + j));
595   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
596     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
597 
598   // Add a new symbol operand which is the tile size for this loop.
599   lbOperands.push_back(tileSize);
600   ubOperands.push_back(tileSize);
601 
602   SmallVector<AffineExpr, 4> lbBoundExprs;
603   SmallVector<AffineExpr, 4> ubBoundExprs;
604   lbBoundExprs.reserve(origLbMap.getNumResults());
605   ubBoundExprs.reserve(origUbMap.getNumResults());
606 
607   // Get tiling parameter as an affine expression for lb/ub.
608   AffineExpr lbTileParameter = b.getAffineSymbolExpr(origLbMap.getNumSymbols());
609   AffineExpr ubTileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
610 
  // Insert lb as ((inter-tile loop IV - origlb) * tilingParameter) + origlb.
612   lbBoundExprs.push_back(
613       ((lbLoopIvExpr - origLowerBoundExpr) * lbTileParameter) +
614       origLowerBoundExpr);
615 
616   // Get the origLoopStep as an affine expression.
617   AffineExpr origLoopStep = b.getAffineConstantExpr(origLoop.getStep());
618 
  // Insert ub as ((inter-tile loop IV - origlb) * tilingParameter) +
  // (tilingParameter * origLoopStep) + origlb.
621   ubBoundExprs.push_back(
622       ((ubLoopIvExpr - origLowerBoundExpr) * ubTileParameter) +
623       (ubTileParameter * origLoopStep) + origLowerBoundExpr);
624 
625   ubBoundExprs.append(origUbMap.getResults().begin(),
626                       origUbMap.getResults().end());
627 
628   AffineMap lbMap =
629       AffineMap::get(origLbMap.getNumDims() + 1, origLbMap.getNumSymbols() + 1,
630                      lbBoundExprs, b.getContext());
631   newIntraTileLoop.setLowerBound(lbOperands, lbMap);
632 
633   AffineMap ubMap =
634       AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols() + 1,
635                      ubBoundExprs, b.getContext());
636   newIntraTileLoop.setUpperBound(ubOperands, ubMap);
637 
638   // Original loop step must be preserved.
639   newIntraTileLoop.setStep(origLoop.getStep());
640 }
641 
642 /// Set lower and upper bounds of inter-tile loops for parametric tiling.
643 //  TODO: Handle non-constant lower bounds.
644 static void setInterTileBoundsParametric(OpBuilder &b, AffineForOp origLoop,
645                                          AffineForOp newLoop, Value tileSize) {
646   OperandRange newLbOperands = origLoop.getLowerBoundOperands();
647 
648   // The lower bounds for inter-tile loops are same as the corresponding lower
649   // bounds of original loops.
650   newLoop.setLowerBound(newLbOperands, origLoop.getLowerBoundMap());
651 
  // The new upper bound map for inter-tile loops, assuming constant lower
  // bounds, is now originalLowerBound + ceildiv((originalUpperBound -
  // originalLowerBound), tilingParameter), where the tiling parameter is the
  // respective tile size for that loop. E.g., if the original ubMap was
  // ()->(1024), the new map will be
  // ()[s0] -> (lb + (1024 - lb) ceildiv s0), where s0 is the tiling
  // parameter. Therefore a new symbol operand is inserted in the map and the
  // result expression is overwritten.
660 
661   assert(origLoop.hasConstantLowerBound() &&
662          "expected input loops to have constant lower bound.");
663 
  // Get lower bound of original loop as an affine expression.
  AffineExpr origLowerBoundExpr =
      b.getAffineConstantExpr(origLoop.getConstantLowerBound());
668 
669   // Add dim operands from original upper bound.
670   SmallVector<Value, 4> ubOperands;
671   AffineBound ub = origLoop.getUpperBound();
672   ubOperands.reserve(ub.getNumOperands() + 1);
673   AffineMap origUbMap = ub.getMap();
674   for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
675     ubOperands.push_back(ub.getOperand(j));
676 
677   // Add symbol operands from original upper bound.
678   for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
679     ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
680 
681   // Add a new symbol operand which is the tile size for this loop.
682   ubOperands.push_back(tileSize);
683 
684   // Get tiling parameter as an affine expression.
685   AffineExpr tileParameter = b.getAffineSymbolExpr(origUbMap.getNumSymbols());
686 
687   SmallVector<AffineExpr, 4> boundExprs;
688   boundExprs.reserve(origUbMap.getNumResults());
689   int64_t origUpperBound;
690   AffineExpr origUpperBoundExpr;
691 
692   // If upper bound for the original loop is constant, then the constant can
693   // be obtained as an affine expression straight away.
694   if (origLoop.hasConstantUpperBound()) {
695     origUpperBound = origLoop.getConstantUpperBound();
696 
697     // Get original constant upper bound as an affine expression.
698     origUpperBoundExpr = b.getAffineConstantExpr(origUpperBound);
699 
    // Insert the bound as originalLowerBound + ceildiv((originalUpperBound -
    // originalLowerBound), tilingParameter).
702     boundExprs.push_back(
703         origLowerBoundExpr +
704         (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
  } else {
    // If the upper bound for the original loop is not constant, then two
    // cases are possible, although their handling is the same. 1.) The
    // ubMap has only one result expression. E.g.,
    //    affine.for %i = 5 to %ub
    //
    // A symbol operand is added which represents the tiling parameter. The
    // new loop bounds here will be like ()[s0, s1] -> ((s0 - 5) ceildiv s1 + 5)
    // where 's0' is the original upper bound and 's1' is the tiling
    // parameter. 2.) The ubMap has more than one result expression. E.g.,
    //    #map0 = affine_map<()[s0, s1] -> (s0, s1)>
    //    affine.for %i = 5 to min #map0()[%s0, %s1]
    //
    // A symbol operand is added which represents the tiling parameter. The
    // new loop bounds will be like ()[s0, s1, s2] -> ((s0 - 5) ceildiv s2 + 5,
    // (s1 - 5) ceildiv s2 + 5), where s2 is the tiling parameter.
721 
722     // Insert the bounds as originalLowerBound + ceildiv((originalUpperBound -
723     // originalLowerBound), tilingParameter).
724     for (AffineExpr origUpperBoundExpr : origUbMap.getResults())
725       boundExprs.push_back(
726           origLowerBoundExpr +
727           (origUpperBoundExpr - origLowerBoundExpr).ceilDiv(tileParameter));
728   }
729 
730   AffineMap ubMap =
731       AffineMap::get(origUbMap.getNumDims(), origUbMap.getNumSymbols() + 1,
732                      boundExprs, b.getContext());
733   newLoop.setUpperBound(ubOperands, ubMap);
734 
735   // Original loop step must be preserved.
736   newLoop.setStep(origLoop.getStep());
737 }
738 
739 /// Constructs and sets new loop bounds after tiling for the case of
740 /// hyper-rectangular index sets, where the bounds of one dimension do not
741 /// depend on other dimensions and tiling parameters are captured from SSA
742 /// values. Bounds of each dimension can thus be treated independently,
743 /// and deriving the new bounds is much simpler and faster than for the case of
744 /// tiling arbitrary polyhedral shapes.
745 static void constructParametricallyTiledIndexSetHyperRect(
746     MutableArrayRef<AffineForOp> origLoops,
747     MutableArrayRef<AffineForOp> newLoops, ArrayRef<Value> tileSizes) {
  assert(!origLoops.empty() && "expected at least one loop in band");
749   assert(origLoops.size() == tileSizes.size() &&
750          "expected tiling parameter for each loop in band.");
751 
752   OpBuilder b(origLoops[0].getOperation());
753   unsigned width = origLoops.size();
754 
755   // Set bounds for tile space loops.
756   for (unsigned i = 0; i < width; ++i) {
757     setInterTileBoundsParametric(b, origLoops[i], newLoops[i], tileSizes[i]);
758   }
759 
760   // Set bounds for intra-tile loops.
761   for (unsigned i = 0; i < width; ++i) {
762     setIntraTileBoundsParametric(b, origLoops[i], newLoops[i],
763                                  newLoops[i + width], tileSizes[i]);
764   }
765 }
766 
767 /// Constructs and sets new loop bounds after tiling for the case of
768 /// hyper-rectangular index sets, where the bounds of one dimension do not
769 /// depend on other dimensions. Bounds of each dimension can thus be treated
770 /// independently, and deriving the new bounds is much simpler and faster
771 /// than for the case of tiling arbitrary polyhedral shapes.
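///
/// As an illustrative example: tiling
///   affine.for %i = 0 to 256 step 2
/// with a tile size of 32 produces a tile-space loop
///   affine.for %ii = 0 to 256 step 64
/// (step = tileSize * origStep) and an intra-tile loop
///   affine.for %i = %ii to %ii + 64 step 2
/// with no min expression needed since the trip count is a multiple of the
/// tile size.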
772 static void
773 constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
774                                 MutableArrayRef<AffineForOp> newLoops,
775                                 ArrayRef<unsigned> tileSizes) {
776   assert(!origLoops.empty());
777   assert(origLoops.size() == tileSizes.size());
778 
779   OpBuilder b(origLoops[0].getOperation());
780   unsigned width = origLoops.size();
781 
782   // Bounds for tile space loops.
783   for (unsigned i = 0; i < width; i++) {
784     OperandRange newLbOperands = origLoops[i].getLowerBoundOperands();
785     OperandRange newUbOperands = origLoops[i].getUpperBoundOperands();
786     newLoops[i].setLowerBound(newLbOperands, origLoops[i].getLowerBoundMap());
787     newLoops[i].setUpperBound(newUbOperands, origLoops[i].getUpperBoundMap());
788     // If the step size of original loop is x and tileSize is y then after
789     // tiling the tile space loops' step size becomes x*y.
790     newLoops[i].setStep(tileSizes[i] * origLoops[i].getStep());
791   }
792   // Bounds for intra-tile loops.
793   for (unsigned i = 0; i < width; i++) {
794     int64_t largestDiv = getLargestDivisorOfTripCount(origLoops[i]);
795     Optional<uint64_t> mayBeConstantCount = getConstantTripCount(origLoops[i]);
796     // The lower bound is just the tile-space loop.
797     AffineMap lbMap = b.getDimIdentityMap();
798     newLoops[width + i].setLowerBound(
799         /*operands=*/newLoops[i].getInductionVar(), lbMap);
    // The step size of the intra-tile loops is just the original loop's step
    // size.
801     newLoops[width + i].setStep(origLoops[i].getStep());
802 
803     // Set the upper bound.
804     if (mayBeConstantCount && mayBeConstantCount.getValue() < tileSizes[i]) {
805       // Trip count is less than the tile size: upper bound is lower bound +
806       // trip count * stepSize.
807       AffineMap ubMap = b.getSingleDimShiftAffineMap(
808           mayBeConstantCount.getValue() * origLoops[i].getStep());
809       newLoops[width + i].setUpperBound(
810           /*operands=*/newLoops[i].getInductionVar(), ubMap);
811     } else if (largestDiv % tileSizes[i] != 0) {
812       // Intra-tile loop ii goes from i to min(i + tileSize * stepSize, ub_i).
813       // Construct the upper bound map; the operands are the original operands
814       // with 'i' (tile-space loop) appended to it. The new upper bound map is
815       // the original one with an additional expression i + tileSize * stepSize
816       // appended.
817 
818       // Add dim operands from original upper bound.
819       SmallVector<Value, 4> ubOperands;
820       AffineBound ub = origLoops[i].getUpperBound();
821       ubOperands.reserve(ub.getNumOperands() + 1);
822       AffineMap origUbMap = ub.getMap();
823       for (unsigned j = 0, e = origUbMap.getNumDims(); j < e; ++j)
824         ubOperands.push_back(ub.getOperand(j));
825 
826       // Add dim operand for new loop upper bound.
827       ubOperands.push_back(newLoops[i].getInductionVar());
828 
829       // Add symbol operands from original upper bound.
830       for (unsigned j = 0, e = origUbMap.getNumSymbols(); j < e; ++j)
831         ubOperands.push_back(ub.getOperand(origUbMap.getNumDims() + j));
832 
833       SmallVector<AffineExpr, 4> boundExprs;
834       boundExprs.reserve(1 + origUbMap.getNumResults());
835       AffineExpr dim = b.getAffineDimExpr(origUbMap.getNumDims());
836       // The new upper bound map is the original one with an additional
837       // expression i + tileSize * stepSize (of original loop) appended.
838       boundExprs.push_back(dim + tileSizes[i] * origLoops[i].getStep());
839       boundExprs.append(origUbMap.getResults().begin(),
840                         origUbMap.getResults().end());
841       AffineMap ubMap =
842           AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
843                          boundExprs, b.getContext());
844       newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
845     } else {
846       // No need of the min expression.
847       AffineExpr dim = b.getAffineDimExpr(0);
848       AffineMap ubMap =
849           AffineMap::get(1, 0, dim + tileSizes[i] * origLoops[i].getStep());
850       newLoops[width + i].setUpperBound(newLoops[i].getInductionVar(), ubMap);
851     }
852   }
853 }
854 
/// Tiles the specified band of perfectly nested loops, creating tile-space
/// loops and intra-tile loops. A band is a contiguous set of loops.
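///
/// A minimal usage sketch (hypothetical caller code, assuming `rootForOp` is
/// the outermost loop of a band):
///   SmallVector<AffineForOp, 6> band, tiled;
///   getPerfectlyNestedLoops(band, rootForOp);
///   SmallVector<unsigned, 4> sizes(band.size(), 32);
///   if (failed(tilePerfectlyNested(band, sizes, &tiled)))
///     return failure();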
857 //  TODO: handle non hyper-rectangular spaces.
858 LogicalResult
859 mlir::tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
860                           ArrayRef<unsigned> tileSizes,
861                           SmallVectorImpl<AffineForOp> *tiledNest) {
862   if (input.empty())
863     return success();
864 
865   if (failed(performPreTilingChecks(input, tileSizes)))
866     return failure();
867 
868   MutableArrayRef<AffineForOp> origLoops = input;
869   AffineForOp rootAffineForOp = origLoops[0];
870 
871   // Note that width is at least one since the band isn't empty.
872   unsigned width = input.size();
873   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
874 
875   // Construct a tiled loop nest without setting their bounds. Bounds are
876   // set later.
877   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
878 
879   SmallVector<Value, 8> origLoopIVs;
880   extractForInductionVars(input, &origLoopIVs);
881 
882   // Set loop bounds for the tiled loop nest.
883   constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
884 
885   // Replace original IVs with intra-tile loop IVs.
886   for (unsigned i = 0; i < width; i++)
887     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
888 
889   // Erase the old loop nest.
890   rootAffineForOp.erase();
891 
892   if (tiledNest)
893     *tiledNest = std::move(tiledLoops);
894 
895   return success();
896 }
897 
898 /// Tiles the specified band of perfectly nested loops creating tile-space
899 /// loops and intra-tile loops, using SSA values as tiling parameters. A band
900 /// is a contiguous set of loops.
901 //  TODO: handle non hyper-rectangular spaces.
902 LogicalResult
903 mlir::tilePerfectlyNestedParametric(MutableArrayRef<AffineForOp> input,
904                                     ArrayRef<Value> tileSizes,
905                                     SmallVectorImpl<AffineForOp> *tiledNest) {
906   if (input.empty())
907     return success();
908 
909   if (failed(performPreTilingChecks(input, tileSizes)))
910     return failure();
911 
912   MutableArrayRef<AffineForOp> origLoops = input;
913   AffineForOp rootAffineForOp = origLoops[0];
914   unsigned width = input.size();
915   SmallVector<AffineForOp, 6> tiledLoops(2 * width);
916 
917   // Construct a tiled loop nest without setting their bounds. Bounds are
918   // set later.
919   constructTiledLoopNest(origLoops, rootAffineForOp, width, tiledLoops);
920 
921   SmallVector<Value, 8> origLoopIVs;
922   extractForInductionVars(input, &origLoopIVs);
923 
924   // Set loop bounds for the tiled loop nest.
925   constructParametricallyTiledIndexSetHyperRect(origLoops, tiledLoops,
926                                                 tileSizes);
927 
928   // Replace original IVs with intra-tile loop IVs.
929   for (unsigned i = 0; i < width; i++)
930     origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
931 
932   // Erase the old loop nest.
933   rootAffineForOp.erase();
934 
935   if (tiledNest)
936     *tiledNest = std::move(tiledLoops);
937 
938   return success();
939 }
940 
/// Gets the perfectly nested sequence of loops starting at `root`. A loop is
/// perfectly nested iff the first op in its body is another AffineForOp and
/// the second op is the terminator.
945 void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
946                                    AffineForOp root) {
947   for (unsigned i = 0; i < std::numeric_limits<unsigned>::max(); ++i) {
948     nestedLoops.push_back(root);
949     Block &body = root.getRegion().front();
950     if (body.begin() != std::prev(body.end(), 2))
951       return;
952 
953     root = dyn_cast<AffineForOp>(&body.front());
954     if (!root)
955       return;
956   }
957 }
958 
959 /// Identify valid and profitable bands of loops to tile. This is currently just
960 /// a temporary placeholder to test the mechanics of tiled code generation.
961 /// Returns all maximal outermost perfect loop nests to tile.
962 void mlir::getTileableBands(FuncOp f,
963                             std::vector<SmallVector<AffineForOp, 6>> *bands) {
  // Get maximal perfect nest of 'affine.for' ops starting from root
  // (inclusive).
966   for (AffineForOp forOp : f.getOps<AffineForOp>()) {
967     SmallVector<AffineForOp, 6> band;
968     getPerfectlyNestedLoops(band, forOp);
969     bands->push_back(band);
970   }
971 }
972 
973 /// Unrolls this loop completely.
974 LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
975   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
976   if (mayBeConstantTripCount.hasValue()) {
977     uint64_t tripCount = mayBeConstantTripCount.getValue();
978     if (tripCount == 0)
979       return success();
980     if (tripCount == 1)
981       return promoteIfSingleIteration(forOp);
982     return loopUnrollByFactor(forOp, tripCount);
983   }
984   return failure();
985 }
986 
987 /// Unrolls this loop by the specified factor or by the trip count (if constant)
988 /// whichever is lower.
989 LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
990                                          uint64_t unrollFactor) {
991   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
992   if (mayBeConstantTripCount.hasValue() &&
993       mayBeConstantTripCount.getValue() < unrollFactor)
994     return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
995   return loopUnrollByFactor(forOp, unrollFactor);
996 }
997 
/// Generates unrolled copies of the loop body in 'loopBodyBlock', with
/// associated induction variable 'forOpIV', by 'unrollFactor', calling
/// 'ivRemapFn' to remap 'forOpIV' for each unrolled body. If specified,
/// annotates the ops in each unrolled iteration using 'annotateFn'.
1002 static void generateUnrolledLoop(
1003     Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
1004     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
1005     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
1006     ValueRange iterArgs, ValueRange yieldedValues) {
1007   // Builder to insert unrolled bodies just before the terminator of the body of
1008   // 'forOp'.
1009   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
1010 
1011   if (!annotateFn)
1012     annotateFn = [](unsigned, Operation *, OpBuilder) {};
1013 
1014   // Keep a pointer to the last non-terminator operation in the original block
1015   // so that we know what to clone (since we are doing this in-place).
1016   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
1017 
1018   // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
1019   SmallVector<Value, 4> lastYielded(yieldedValues);
1020 
1021   for (unsigned i = 1; i < unrollFactor; i++) {
1022     BlockAndValueMapping operandMap;
1023 
1024     // Prepare operand map.
1025     operandMap.map(iterArgs, lastYielded);
1026 
1027     // If the induction variable is used, create a remapping to the value for
1028     // this unrolled instance.
1029     if (!forOpIV.use_empty()) {
1030       Value ivUnroll = ivRemapFn(i, forOpIV, builder);
1031       operandMap.map(forOpIV, ivUnroll);
1032     }
1033 
1034     // Clone the original body of 'forOp'.
1035     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
1036       Operation *clonedOp = builder.clone(*it, operandMap);
1037       annotateFn(i, clonedOp, builder);
1038     }
1039 
1040     // Update yielded values.
    for (unsigned j = 0, e = lastYielded.size(); j < e; j++)
      lastYielded[j] = operandMap.lookup(yieldedValues[j]);
1043   }
1044 
1045   // Make sure we annotate the Ops in the original body. We do this last so that
1046   // any annotations are not copied into the cloned Ops above.
1047   for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++)
1048     annotateFn(0, &*it, builder);
1049 
1050   // Update operands of the yield statement.
1051   loopBodyBlock->getTerminator()->setOperands(lastYielded);
1052 }
1053 
1054 /// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip
1055 /// count is not a multiple of `unrollFactor`.
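///
/// E.g. (illustrative): for `affine.for %i = 0 to 10` unrolled by a factor
/// of 4, the cleanup loop covers [8, 10) and the original loop's upper
/// bound is adjusted down to 8.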
1056 static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
1057                                                   uint64_t unrollFactor) {
1058   // Insert the cleanup loop right after 'forOp'.
1059   OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
1060   auto cleanupForOp = cast<AffineForOp>(builder.clone(*forOp));
1061 
1062   // Update uses of `forOp` results. `cleanupForOp` should use `forOp` result
1063   // and produce results for the original users of `forOp` results.
1064   auto results = forOp.getResults();
1065   auto cleanupResults = cleanupForOp.getResults();
1066   auto cleanupIterOperands = cleanupForOp.getIterOperands();
1067 
1068   for (auto e : llvm::zip(results, cleanupResults, cleanupIterOperands)) {
1069     std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
1070     cleanupForOp->replaceUsesOfWith(std::get<2>(e), std::get<0>(e));
1071   }
1072 
1073   AffineMap cleanupMap;
1074   SmallVector<Value, 4> cleanupOperands;
1075   getCleanupLoopLowerBound(forOp, unrollFactor, cleanupMap, cleanupOperands);
1076   if (!cleanupMap)
1077     return failure();
1078 
1079   cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
1080   // Promote the loop body up if this has turned into a single iteration loop.
1081   (void)promoteIfSingleIteration(cleanupForOp);
1082 
1083   // Adjust upper bound of the original loop; this is the same as the lower
1084   // bound of the cleanup loop.
1085   forOp.setUpperBound(cleanupOperands, cleanupMap);
1086   return success();
1087 }
1088 
1089 /// Unrolls this loop by the specified factor. Returns success if the loop
1090 /// is successfully unrolled.
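///
/// As an illustrative example, unrolling
///   affine.for %i = 0 to 16 { "op"(%i) }
/// by a factor of 4 scales the step to 4 and clones the body three more
/// times, with %i remapped to %i + 1, %i + 2, and %i + 3 via affine.apply.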
1091 LogicalResult mlir::loopUnrollByFactor(
1092     AffineForOp forOp, uint64_t unrollFactor,
1093     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
1094   assert(unrollFactor > 0 && "unroll factor should be positive");
1095 
1096   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1097   if (unrollFactor == 1) {
1098     if (mayBeConstantTripCount.hasValue() &&
1099         mayBeConstantTripCount.getValue() == 1 &&
1100         failed(promoteIfSingleIteration(forOp)))
1101       return failure();
1102     return success();
1103   }
1104 
1105   // Nothing in the loop body other than the terminator.
1106   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1107     return success();
1108 
1109   // If the trip count is lower than the unroll factor, no unrolled body.
1110   // TODO: option to specify cleanup loop unrolling.
1111   if (mayBeConstantTripCount.hasValue() &&
1112       mayBeConstantTripCount.getValue() < unrollFactor)
1113     return failure();
1114 
1115   // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
1116   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
1117     // Loops where the lower bound is a max expression or the upper bound is
1118     // a min expression and the trip count doesn't divide the unroll factor
1119     // can't be unrolled since the lower bound of the cleanup loop in such cases
1120     // cannot be expressed as an affine function or a max over affine functions.
1121     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1122         forOp.getUpperBoundMap().getNumResults() != 1)
1123       return failure();
1124     if (failed(generateCleanupLoopForUnroll(forOp, unrollFactor)))
1125       assert(false && "cleanup loop lower bound map for single result lower "
1126                       "and upper bound maps can always be determined");
1127   }
1128 
1129   ValueRange iterArgs(forOp.getRegionIterArgs());
1130   auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
1131 
1132   // Scale the step of loop being unrolled by unroll factor.
1133   int64_t step = forOp.getStep();
1134   forOp.setStep(step * unrollFactor);
1135   generateUnrolledLoop(
1136       forOp.getBody(), forOp.getInductionVar(), unrollFactor,
1137       [&](unsigned i, Value iv, OpBuilder b) {
1138         // iv' = iv + i * step
1139         auto d0 = b.getAffineDimExpr(0);
1140         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1141         return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv);
1142       },
1143       /*annotateFn=*/annotateFn,
1144       /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
1145 
1146   // Promote the loop body up if this has turned into a single iteration loop.
1147   (void)promoteIfSingleIteration(forOp);
1148   return success();
1149 }
1150 
1151 LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
1152                                             uint64_t unrollJamFactor) {
1153   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1154   if (mayBeConstantTripCount.hasValue() &&
1155       mayBeConstantTripCount.getValue() < unrollJamFactor)
1156     return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
1157   return loopUnrollJamByFactor(forOp, unrollJamFactor);
1158 }
1159 
1160 /// Check if all control operands of all loops are defined outside of `forOp`
1161 /// and return false if not.
1162 static bool areInnerBoundsInvariant(AffineForOp forOp) {
1163   auto walkResult = forOp.walk([&](AffineForOp aForOp) {
1164     for (auto controlOperand : aForOp.getControlOperands()) {
1165       if (!forOp.isDefinedOutsideOfLoop(controlOperand))
1166         return WalkResult::interrupt();
1167     }
1168     return WalkResult::advance();
1169   });
1170   return !walkResult.wasInterrupted();
1171 }
1172 
// Gathers all maximal sub-blocks of operations that do not themselves
// include a for op (although an operation could have a descendant for op
// in its region tree). Ignores the block terminators.
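//
// For instance (illustrative), for a block of the form
//   { op1; op2; affine.for {...}; op3; <terminator> }
// the sub-blocks gathered are (op1..op2) and (op3..op3), and the walk
// recurses into the inner affine.for's region.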
1176 struct JamBlockGatherer {
1177   // Store iterators to the first and last op of each sub-block found.
1178   std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
1179 
1180   // This is a linear time walk.
1181   void walk(Operation *op) {
1182     for (auto &region : op->getRegions())
1183       for (auto &block : region)
1184         walk(block);
1185   }
1186 
1187   void walk(Block &block) {
1188     for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
1189       auto subBlockStart = it;
1190       while (it != e && !isa<AffineForOp>(&*it))
1191         ++it;
1192       if (it != subBlockStart)
1193         subBlocks.emplace_back(subBlockStart, std::prev(it));
1194       // Process all for ops that appear next.
1195       while (it != e && isa<AffineForOp>(&*it))
1196         walk(&*it++);
1197     }
1198   }
1199 };
1200 
1201 /// Unrolls and jams this loop by the specified factor.
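///
/// As an illustrative example, unroll-and-jam by a factor of 2 rewrites
///   affine.for %i = 0 to 8 { affine.for %j = 0 to %N { "op"(%i, %j) } }
/// into
///   affine.for %i = 0 to 8 step 2 {
///     affine.for %j = 0 to %N { "op"(%i, %j); "op"(%i + 1, %j) }
///   }
/// i.e. the unrolled copies are jammed together into the common inner loop.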
1202 LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
1203                                           uint64_t unrollJamFactor) {
1204   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
1205 
1206   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1207   if (unrollJamFactor == 1) {
1208     if (mayBeConstantTripCount.hasValue() &&
1209         mayBeConstantTripCount.getValue() == 1 &&
1210         failed(promoteIfSingleIteration(forOp)))
1211       return failure();
1212     return success();
1213   }
1214 
1215   // Nothing in the loop body other than the terminator.
1216   if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
1217     return success();
1218 
1219   // If the trip count is lower than the unroll jam factor, no unroll jam.
1220   if (mayBeConstantTripCount.hasValue() &&
1221       mayBeConstantTripCount.getValue() < unrollJamFactor) {
1222     LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n");
1223     return failure();
1224   }
1225 
1226   // If any control operand of any inner loop of `forOp` is defined within
1227   // `forOp`, no unroll jam.
1228   if (!areInnerBoundsInvariant(forOp))
1229     return failure();
1230 
1231   // Gather all sub-blocks to jam upon the loop being unrolled.
1232   JamBlockGatherer jbg;
1233   jbg.walk(forOp);
1234   auto &subBlocks = jbg.subBlocks;
1235 
1236   // Collect loops with iter_args.
1237   SmallVector<AffineForOp, 4> loopsWithIterArgs;
1238   forOp.walk([&](AffineForOp aForOp) {
1239     if (aForOp.getNumIterOperands() > 0)
1240       loopsWithIterArgs.push_back(aForOp);
1241   });
1242 
1243   // Get supported reductions to be used for creating reduction ops at the end.
1244   SmallVector<LoopReduction> reductions;
1245   if (forOp.getNumIterOperands() > 0)
1246     getSupportedReductions(forOp, reductions);
1247 
1248   // Generate the cleanup loop if trip count isn't a multiple of
1249   // unrollJamFactor.
1250   if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
    // Loops where the lower bound is a max expression or the upper bound is
    // a min expression and the unroll-jam factor doesn't evenly divide the
    // trip count can't be unroll-jammed, since the lower bound of the cleanup
    // loop in such cases cannot be expressed as an affine function or a max
    // over affine functions.
1255     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
1256         forOp.getUpperBoundMap().getNumResults() != 1)
1257       return failure();
1258     if (failed(generateCleanupLoopForUnroll(forOp, unrollJamFactor)))
1259       assert(false && "cleanup loop lower bound map for single result lower "
1260                       "and upper bound maps can always be determined");
1261   }
1262 
1263   // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
1264   // iteration. There are (`unrollJamFactor` - 1) iterations.
1265   SmallVector<BlockAndValueMapping, 4> operandMaps(unrollJamFactor - 1);
1266 
1267   // For any loop with iter_args, replace it with a new loop that has
1268   // `unrollJamFactor` copies of its iterOperands, iter_args and yield
1269   // operands.
1270   SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
1271   OpBuilder builder(forOp.getContext());
1272   for (AffineForOp oldForOp : loopsWithIterArgs) {
1273     SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
1274     ValueRange oldIterOperands = oldForOp.getIterOperands();
1275     ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
1276     ValueRange oldYieldOperands =
1277         cast<AffineYieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
1278     // Get additional iterOperands, iterArgs, and yield operands. We will
1279     // fix iterOperands and yield operands after cloning of sub-blocks.
1280     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1281       dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
1282       dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
1283       dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
1284     }
1285     // Create a new loop with additional iterOperands, iter_args and yield
1286     // operands. This new loop will take the loop body of the original loop.
1287     AffineForOp newForOp = mlir::replaceForOpWithNewYields(
1288         builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
1289     newLoopsWithIterArgs.push_back(newForOp);
1290     // `forOp` has been replaced with a new loop.
1291     if (oldForOp == forOp)
1292       forOp = newForOp;
1293     assert(oldForOp.use_empty() && "old for op should not have any user");
1294     oldForOp.erase();
1295     // Update `operandMaps` for `newForOp` iterArgs and results.
1296     ValueRange newIterArgs = newForOp.getRegionIterArgs();
1297     unsigned oldNumIterArgs = oldIterArgs.size();
1298     ValueRange newResults = newForOp.getResults();
1299     unsigned oldNumResults = newResults.size() / unrollJamFactor;
1300     assert(oldNumIterArgs == oldNumResults &&
1301            "oldNumIterArgs must be the same as oldNumResults");
1302     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1303       for (unsigned j = 0; j < oldNumIterArgs; ++j) {
1304         // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
1305         // results. Update `operandMaps[i - 1]` to map old iterArgs and results
1306         // to those in the `i`th new set.
1307         operandMaps[i - 1].map(newIterArgs[j],
1308                                newIterArgs[i * oldNumIterArgs + j]);
1309         operandMaps[i - 1].map(newResults[j],
1310                                newResults[i * oldNumResults + j]);
1311       }
1312     }
1313   }
1314 
  // Scale the step of the loop being unroll-jammed by the unroll-jam factor.
1316   int64_t step = forOp.getStep();
1317   forOp.setStep(step * unrollJamFactor);
1318 
1319   auto forOpIV = forOp.getInductionVar();
1320   // Unroll and jam (appends unrollJamFactor - 1 additional copies).
1321   for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1322     for (auto &subBlock : subBlocks) {
1323       // Builder to insert unroll-jammed bodies. Insert right at the end of
1324       // sub-block.
1325       OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
1326 
1327       // If the induction variable is used, create a remapping to the value for
1328       // this unrolled instance.
1329       if (!forOpIV.use_empty()) {
1330         // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
1331         auto d0 = builder.getAffineDimExpr(0);
1332         auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
1333         auto ivUnroll =
1334             builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
1335         operandMaps[i - 1].map(forOpIV, ivUnroll);
1336       }
1337       // Clone the sub-block being unroll-jammed.
1338       for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
1339         builder.clone(*it, operandMaps[i - 1]);
1340     }
1341     // Fix iterOperands and yield op operands of newly created loops.
1342     for (auto newForOp : newLoopsWithIterArgs) {
1343       unsigned oldNumIterOperands =
1344           newForOp.getNumIterOperands() / unrollJamFactor;
1345       unsigned numControlOperands = newForOp.getNumControlOperands();
1346       auto yieldOp = cast<AffineYieldOp>(newForOp.getBody()->getTerminator());
1347       unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
1348       assert(oldNumIterOperands == oldNumYieldOperands &&
1349              "oldNumIterOperands must be the same as oldNumYieldOperands");
1350       for (unsigned j = 0; j < oldNumIterOperands; ++j) {
1351         // The `i`th duplication of an old iterOperand or yield op operand
1352         // needs to be replaced with a mapped value from `operandMaps[i - 1]`
1353         // if such mapped value exists.
1354         newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
1355                             operandMaps[i - 1].lookupOrDefault(
1356                                 newForOp.getOperand(numControlOperands + j)));
1357         yieldOp.setOperand(
1358             i * oldNumYieldOperands + j,
1359             operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
1360       }
1361     }
1362   }
1363   if (forOp.getNumResults() > 0) {
    // Create reduction ops to combine every `unrollJamFactor` related results
    // into one value. For example, for %0:2 = affine.for ... with an addf
    // reduction, we add %1 = arith.addf %0#0, %0#1 and replace all other uses
    // of %0#0 with %1.
1368     builder.setInsertionPointAfter(forOp);
1369     auto loc = forOp.getLoc();
1370     unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
1371     for (LoopReduction &reduction : reductions) {
1372       unsigned pos = reduction.iterArgPosition;
1373       Value lhs = forOp.getResult(pos);
1374       Value rhs;
1375       SmallPtrSet<Operation *, 4> newOps;
1376       for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
1377         rhs = forOp.getResult(i * oldNumResults + pos);
1378         // Create ops based on reduction type.
1379         lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
1380         if (!lhs)
1381           return failure();
1382         Operation *op = lhs.getDefiningOp();
1383         assert(op && "Reduction op should have been created");
1384         newOps.insert(op);
1385       }
1386       // Replace all uses except those in newly created reduction ops.
1387       forOp.getResult(pos).replaceAllUsesExcept(lhs, newOps);
1388     }
1389   }
1390 
1391   // Promote the loop body up if this has turned into a single iteration loop.
1392   (void)promoteIfSingleIteration(forOp);
1393   return success();
1394 }
1395 
1396 /// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
1397 /// nested within 'forOpA' as the only non-terminator operation in its block.
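/// For example, with a hypothetical body op "S":
///
///   affine.for %i = 0 to 10 {
///     affine.for %j = 0 to 20 {
///       "S"(%i, %j)
///     }
///   }
///
/// becomes:
///
///   affine.for %j = 0 to 20 {
///     affine.for %i = 0 to 10 {
///       "S"(%i, %j)
///     }
///   }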
1398 void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
1399   assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
1400   auto &forOpABody = forOpA.getBody()->getOperations();
1401   auto &forOpBBody = forOpB.getBody()->getOperations();
1402 
  // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
  // before forOpA (in forOpA's parent's block); this leaves forOpA's body
  // containing only the terminator.
1406   forOpA->getBlock()->getOperations().splice(Block::iterator(forOpA),
1407                                              forOpABody, forOpABody.begin(),
1408                                              std::prev(forOpABody.end()));
1409   // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
1410   // body (this leaves forOpB's body containing only the terminator).
1411   forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
1412                     std::prev(forOpBBody.end()));
1413   // 3) Splice forOpA into the beginning of forOpB's body.
1414   forOpBBody.splice(forOpBBody.begin(), forOpA->getBlock()->getOperations(),
1415                     Block::iterator(forOpA));
1416 }
1417 
1418 // Checks each dependence component against the permutation to see if the
1419 // desired loop interchange would violate dependences by making the
1420 // dependence component lexicographically negative.
1421 static bool checkLoopInterchangeDependences(
1422     const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
1423     ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
1424   // Invert permutation map.
1425   unsigned maxLoopDepth = loops.size();
1426   SmallVector<unsigned, 4> loopPermMapInv;
1427   loopPermMapInv.resize(maxLoopDepth);
1428   for (unsigned i = 0; i < maxLoopDepth; ++i)
1429     loopPermMapInv[loopPermMap[i]] = i;
1430 
1431   // Check each dependence component against the permutation to see if the
1432   // desired loop interchange permutation would make the dependence vectors
1433   // lexicographically negative.
1434   // Example 1: [-1, 1][0, 0]
1435   // Example 2: [0, 0][-1, 1]
1436   for (const auto &depComps : depCompsVec) {
1437     assert(depComps.size() >= maxLoopDepth);
1438     // Check if the first non-zero dependence component is positive.
1439     // This iterates through loops in the desired order.
1440     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1441       unsigned permIndex = loopPermMapInv[j];
1442       assert(depComps[permIndex].lb.hasValue());
1443       int64_t depCompLb = depComps[permIndex].lb.getValue();
1444       if (depCompLb > 0)
1445         break;
1446       if (depCompLb < 0)
1447         return false;
1448     }
1449   }
1450   return true;
1451 }
1452 
1453 /// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
1454 /// nested sequence of loops in 'loops' would violate dependences.
1455 bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
1456                                              ArrayRef<unsigned> loopPermMap) {
1457   // Gather dependence components for dependences between all ops in loop nest
1458   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1459   assert(loopPermMap.size() == loops.size());
1460   unsigned maxLoopDepth = loops.size();
1461   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1462   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1463   return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
1464 }
1465 
1466 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
1467 /// in it from outermost to innermost.
1468 bool LLVM_ATTRIBUTE_UNUSED
1469 mlir::isPerfectlyNested(ArrayRef<AffineForOp> loops) {
1470   assert(!loops.empty() && "no loops provided");
1471 
1472   // We already know that the block can't be empty.
1473   auto hasTwoElements = [](Block *block) {
1474     auto secondOpIt = std::next(block->begin());
1475     return secondOpIt != block->end() && &*secondOpIt == &block->back();
1476   };
1477 
1478   auto enclosingLoop = loops.front();
1479   for (auto loop : loops.drop_front()) {
1480     auto parentForOp = dyn_cast<AffineForOp>(loop->getParentOp());
1481     // parentForOp's body should be just this loop and the terminator.
1482     if (parentForOp != enclosingLoop || !hasTwoElements(parentForOp.getBody()))
1483       return false;
1484     enclosingLoop = loop;
1485   }
1486   return true;
1487 }
1488 
1489 // input[i] should move from position i -> permMap[i]. Returns the position in
1490 // `input` that becomes the new outermost loop.
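// For example, for input = {%i, %j, %k} (outermost to innermost) and
// permMap = {2, 0, 1}, the %i loop sinks to the innermost position, the %j
// loop becomes outermost, and %k ends up in the middle; the permuted nesting
// order is %j, %k, %i, and the returned value is 1 (the position of %j in
// `input`).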
1491 unsigned mlir::permuteLoops(MutableArrayRef<AffineForOp> input,
1492                             ArrayRef<unsigned> permMap) {
1493   assert(input.size() == permMap.size() && "invalid permutation map size");
1494   // Check whether the permutation spec is valid. This is a small vector - we'll
1495   // just sort and check if it's iota.
1496   SmallVector<unsigned, 4> checkPermMap(permMap.begin(), permMap.end());
1497   llvm::sort(checkPermMap);
1498   if (llvm::any_of(llvm::enumerate(checkPermMap),
1499                    [](const auto &en) { return en.value() != en.index(); }))
1500     assert(false && "invalid permutation map");
1501 
1502   // Nothing to do.
1503   if (input.size() < 2)
1504     return 0;
1505 
1506   assert(isPerfectlyNested(input) && "input not perfectly nested");
1507 
1508   // Compute the inverse mapping, invPermMap: since input[i] goes to position
1509   // permMap[i], position i of the permuted nest is at input[invPermMap[i]].
1510   SmallVector<std::pair<unsigned, unsigned>, 4> invPermMap;
1511   for (unsigned i = 0, e = input.size(); i < e; ++i)
1512     invPermMap.push_back({permMap[i], i});
1513   llvm::sort(invPermMap);
1514 
1515   // Move the innermost loop body to the loop that would be the innermost in the
1516   // permuted nest (only if the innermost loop is going to change).
1517   if (permMap.back() != input.size() - 1) {
1518     auto *destBody = input[invPermMap.back().second].getBody();
1519     auto *srcBody = input.back().getBody();
1520     destBody->getOperations().splice(destBody->begin(),
1521                                      srcBody->getOperations(), srcBody->begin(),
1522                                      std::prev(srcBody->end()));
1523   }
1524 
1525   // We'll move each loop in `input` in the reverse order so that its body is
1526   // empty when we are moving it; this incurs zero copies and no erasing.
1527   for (int i = input.size() - 1; i >= 0; --i) {
1528     // If this has to become the outermost loop after permutation, add it to the
1529     // parent block of the original root.
1530     if (permMap[i] == 0) {
1531       // If the root remains the same, nothing to do.
1532       if (i == 0)
1533         continue;
1534       // Make input[i] the new outermost loop moving it into parentBlock.
1535       auto *parentBlock = input[0]->getBlock();
1536       parentBlock->getOperations().splice(Block::iterator(input[0]),
1537                                           input[i]->getBlock()->getOperations(),
1538                                           Block::iterator(input[i]));
1539       continue;
1540     }
1541 
1542     // If the parent in the permuted order is the same as in the original,
1543     // nothing to do.
1544     unsigned parentPosInInput = invPermMap[permMap[i] - 1].second;
1545     if (i > 0 && static_cast<unsigned>(i - 1) == parentPosInInput)
1546       continue;
1547 
1548     // Move input[i] to its surrounding loop in the transformed nest.
1549     auto *destBody = input[parentPosInInput].getBody();
1550     destBody->getOperations().splice(destBody->begin(),
1551                                      input[i]->getBlock()->getOperations(),
1552                                      Block::iterator(input[i]));
1553   }
1554 
1555   return invPermMap[0].second;
1556 }
1557 
1558 // Sinks all sequential loops to the innermost levels (while preserving
1559 // relative order among them) and moves all parallel loops to the
1560 // outermost (while again preserving relative order among them).
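// For illustration: in a 2-d nest where the outer loop carries a dependence
// and the inner loop is parallel, isParallelLoop below evaluates to
// [false, true] and loopPermMap to [1, 0], i.e., the parallel inner loop is
// raised to the outermost position while the sequential loop sinks inside
// (subject to the dependence check at the end).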
1561 AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
1562   SmallVector<AffineForOp, 4> loops;
1563   getPerfectlyNestedLoops(loops, forOp);
1564   if (loops.size() < 2)
1565     return forOp;
1566 
1567   // Gather dependence components for dependences between all ops in loop nest
1568   // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
1569   unsigned maxLoopDepth = loops.size();
1570   std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
1571   getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
1572 
1573   // Mark loops as either parallel or sequential.
1574   SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
1575   for (auto &depComps : depCompsVec) {
1576     assert(depComps.size() >= maxLoopDepth);
1577     for (unsigned j = 0; j < maxLoopDepth; ++j) {
1578       DependenceComponent &depComp = depComps[j];
1579       assert(depComp.lb.hasValue() && depComp.ub.hasValue());
1580       if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
1581         isParallelLoop[j] = false;
1582     }
1583   }
1584 
1585   // Count the number of parallel loops.
1586   unsigned numParallelLoops = 0;
1587   for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
1588     if (isParallelLoop[i])
1589       ++numParallelLoops;
1590 
1591   // Compute permutation of loops that sinks sequential loops (and thus raises
1592   // parallel loops) while preserving relative order.
1593   SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
1594   unsigned nextSequentialLoop = numParallelLoops;
1595   unsigned nextParallelLoop = 0;
1596   for (unsigned i = 0; i < maxLoopDepth; ++i) {
1597     if (isParallelLoop[i]) {
1598       loopPermMap[i] = nextParallelLoop++;
1599     } else {
1600       loopPermMap[i] = nextSequentialLoop++;
1601     }
1602   }
1603 
1604   // Check if permutation 'loopPermMap' would violate dependences.
1605   if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
1606     return forOp;
1607   // Perform loop interchange according to permutation 'loopPermMap'.
1608   unsigned loopNestRootIndex = permuteLoops(loops, loopPermMap);
1609   return loops[loopNestRootIndex];
1610 }
1611 
1612 // Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
1613 // lower (resp. upper) loop bound. When called for both the lower and upper
1614 // bounds, the resulting IR resembles:
1615 //
1616 // ```mlir
//    affine.for %i = max (`iv`, ...) to min (`iv` + `offset`) {
1618 //      ...
1619 //    }
1620 // ```
1621 static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
1622                                 SmallVector<Value, 4> *operands,
1623                                 int64_t offset = 0) {
1624   auto bounds = llvm::to_vector<4>(map->getResults());
1625   bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
1626   operands->insert(operands->begin() + map->getNumDims(), iv);
1627   *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds,
1628                         b.getContext());
1629   canonicalizeMapAndOperands(map, operands);
1630 }
1631 
1632 // Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
1633 // Stripmine-sink is a primitive building block for generalized tiling of
1634 // imperfectly nested loops.
1635 // This transformation is purely mechanical and does not check legality,
1636 // profitability or even structural correctness. It is the user's
1637 // responsibility to specify `targets` that are dominated by `forOp`.
1638 // Returns the new AffineForOps, one per `targets`, nested immediately under
1639 // each of the `targets`.
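// As a rough sketch (op "S" hypothetical), stripmineSink(%i-loop,
// /*factor=*/32, /*targets=*/{%j-loop}) on:
//
//   affine.for %i = 0 to 256 {
//     affine.for %j = 0 to 512 {
//       "S"(%i, %j)
//     }
//   }
//
// produces approximately:
//
//   affine.for %i = 0 to 256 step 32 {
//     affine.for %j = 0 to 512 {
//       affine.for %ii = max(0, %i) to min(256, %i + 32) {
//         "S"(%ii, %j)
//       }
//     }
//   }
//
// where the intra-tile loop's bounds are the original bounds augmented with
// the outer IV (and IV + scaled step) via augmentMapAndBounds above.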
1640 static SmallVector<AffineForOp, 8>
1641 stripmineSink(AffineForOp forOp, uint64_t factor,
1642               ArrayRef<AffineForOp> targets) {
1643   auto originalStep = forOp.getStep();
1644   auto scaledStep = originalStep * factor;
1645   forOp.setStep(scaledStep);
1646 
1647   OpBuilder b(forOp->getBlock(), std::next(Block::iterator(forOp)));
1648 
1649   // Lower-bound map creation.
1650   auto lbMap = forOp.getLowerBoundMap();
1651   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1652   augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
1653 
1654   // Upper-bound map creation.
1655   auto ubMap = forOp.getUpperBoundMap();
1656   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1657   augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
1658                       /*offset=*/scaledStep);
1659 
1660   auto iv = forOp.getInductionVar();
1661   SmallVector<AffineForOp, 8> innerLoops;
1662   for (auto t : targets) {
1663     // Insert newForOp before the terminator of `t`.
1664     auto b = OpBuilder::atBlockTerminator(t.getBody());
1665     auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
1666                                           ubOperands, ubMap, originalStep);
1667     auto begin = t.getBody()->begin();
1668     // Skip terminator and `newForOp` which is just before the terminator.
1669     auto nOps = t.getBody()->getOperations().size() - 2;
1670     newForOp.getBody()->getOperations().splice(
1671         newForOp.getBody()->getOperations().begin(),
1672         t.getBody()->getOperations(), begin, std::next(begin, nOps));
1673     replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
1674                                newForOp.region());
1675     innerLoops.push_back(newForOp);
1676   }
1677 
1678   return innerLoops;
1679 }
1680 
// Stripmines `forOp` by `factor` and sinks it under the single `target`.
// Returns the new AffineForOp, nested immediately under `target`.
1683 template <typename SizeType>
1684 static AffineForOp stripmineSink(AffineForOp forOp, SizeType factor,
1685                                  AffineForOp target) {
1686   // TODO: Use cheap structural assertions that targets are nested under
1687   // forOp and that targets are not nested under each other when DominanceInfo
1688   // exposes the capability. It seems overkill to construct a whole function
1689   // dominance tree at this point.
1690   auto res = stripmineSink(forOp, factor, ArrayRef<AffineForOp>(target));
1691   assert(res.size() == 1 && "Expected 1 inner forOp");
1692   return res[0];
1693 }
1694 
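// Tiles the band of loops `forOps` by strip-mining each one by the
// corresponding size in `sizes` and sinking the intra-tile loops under
// `targets`. For example (a sketch), tiling the 2-d nest {%i, %j} with sizes
// {32, 16} and targets = {%j-loop} yields inter-tile loops %i (step 32) and
// %j (step 16) with intra-tile loops %ii and %jj nested innermost, i.e., the
// nest %i, %j, %ii, %jj from outermost to innermost.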
1695 SmallVector<SmallVector<AffineForOp, 8>, 8>
1696 mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
1697            ArrayRef<AffineForOp> targets) {
1698   SmallVector<SmallVector<AffineForOp, 8>, 8> res;
1699   SmallVector<AffineForOp, 8> currentTargets(targets.begin(), targets.end());
1700   for (auto it : llvm::zip(forOps, sizes)) {
1701     auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1702     res.push_back(step);
1703     currentTargets = step;
1704   }
1705   return res;
1706 }
1707 
1708 SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
1709                                        ArrayRef<uint64_t> sizes,
1710                                        AffineForOp target) {
1711   SmallVector<AffineForOp, 8> res;
1712   for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
1713     assert(loops.size() == 1);
1714     res.push_back(loops[0]);
1715   }
1716   return res;
1717 }
1718 
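// Coalesces a perfect nest of normalized loops (unit step, zero lower bound)
// into a single loop. A sketch for a 2-d nest with hypothetical bounds %M and
// %N and body op "S":
//
//   affine.for %i = 0 to %M {
//     affine.for %j = 0 to %N {
//       "S"(%i, %j)
//     }
//   }
//
// becomes, modulo the exact affine.apply maps built below:
//
//   affine.for %k = 0 to (%M * %N) {
//     %i = affine.apply (d0)[s0] -> (d0 floordiv s0)(%k)[%N]
//     %j = affine.apply (d0)[s0] -> (d0 mod s0)(%k)[%N]
//     "S"(%i, %j)
//   }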
1719 LogicalResult mlir::coalesceLoops(MutableArrayRef<AffineForOp> loops) {
1720   if (loops.size() < 2)
1721     return success();
1722 
1723   AffineForOp innermost = loops.back();
1724   AffineForOp outermost = loops.front();
1725   AffineBound ub = outermost.getUpperBound();
1726   AffineMap origUbMap = ub.getMap();
1727   Location loc = outermost.getLoc();
1728   OpBuilder builder(outermost);
1729   for (AffineForOp loop : loops) {
1730     // We only work on normalized loops.
1731     if (loop.getStep() != 1 || !loop.hasConstantLowerBound() ||
1732         loop.getConstantLowerBound() != 0)
1733       return failure();
1734   }
1735   SmallVector<Value, 4> upperBoundSymbols;
1736   SmallVector<Value, 4> ubOperands(ub.getOperands().begin(),
1737                                    ub.getOperands().end());
1738 
1739   // 1. Store the upper bound of the outermost loop in a variable.
1740   Value prev;
1741   if (!llvm::hasSingleElement(origUbMap.getResults()))
1742     prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
1743   else
1744     prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
1745   upperBoundSymbols.push_back(prev);
1746 
1747   // 2. Emit code computing the upper bound of the coalesced loop as product of
1748   // the number of iterations of all loops.
1749   for (AffineForOp loop : loops.drop_front()) {
1750     ub = loop.getUpperBound();
1751     origUbMap = ub.getMap();
1752     ubOperands = ub.getOperands();
1753     Value upperBound;
1754     // If upper bound map has more than one result, take their minimum.
1755     if (!llvm::hasSingleElement(origUbMap.getResults()))
1756       upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands);
1757     else
1758       upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands);
1759     upperBoundSymbols.push_back(upperBound);
1760     SmallVector<Value, 4> operands;
1761     operands.push_back(prev);
1762     operands.push_back(upperBound);
1763     // Maintain running product of loop upper bounds.
1764     prev = builder.create<AffineApplyOp>(
1765         loc,
1766         AffineMap::get(/*numDims=*/1,
1767                        /*numSymbols=*/1,
1768                        builder.getAffineDimExpr(0) *
1769                            builder.getAffineSymbolExpr(0)),
1770         operands);
1771   }
1772   // Set upper bound of the coalesced loop.
1773   AffineMap newUbMap = AffineMap::get(
1774       /*numDims=*/0,
1775       /*numSymbols=*/1, builder.getAffineSymbolExpr(0), builder.getContext());
1776   outermost.setUpperBound(prev, newUbMap);
1777 
1778   builder.setInsertionPointToStart(outermost.getBody());
1779 
1780   // 3. Remap induction variables. For each original loop, the value of the
1781   // induction variable can be obtained by dividing the induction variable of
1782   // the linearized loop by the total number of iterations of the loops nested
1783   // in it modulo the number of iterations in this loop (remove the values
1784   // related to the outer loops):
1785   //   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
1786   // Compute these iteratively from the innermost loop by creating a "running
1787   // quotient" of division by the range.
1788   Value previous = outermost.getInductionVar();
1789   for (unsigned idx = loops.size(); idx > 0; --idx) {
1790     if (idx != loops.size()) {
1791       SmallVector<Value, 4> operands;
1792       operands.push_back(previous);
1793       operands.push_back(upperBoundSymbols[idx]);
1794       previous = builder.create<AffineApplyOp>(
1795           loc,
1796           AffineMap::get(
1797               /*numDims=*/1, /*numSymbols=*/1,
1798               builder.getAffineDimExpr(0).floorDiv(
1799                   builder.getAffineSymbolExpr(0))),
1800           operands);
1801     }
1802     // Modified value of the induction variables of the nested loops after
1803     // coalescing.
1804     Value inductionVariable;
1805     if (idx == 1) {
1806       inductionVariable = previous;
1807     } else {
1808       SmallVector<Value, 4> applyOperands;
1809       applyOperands.push_back(previous);
1810       applyOperands.push_back(upperBoundSymbols[idx - 1]);
1811       inductionVariable = builder.create<AffineApplyOp>(
1812           loc,
1813           AffineMap::get(
1814               /*numDims=*/1, /*numSymbols=*/1,
1815               builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)),
1816           applyOperands);
1817     }
1818     replaceAllUsesInRegionWith(loops[idx - 1].getInductionVar(),
1819                                inductionVariable, loops.back().region());
1820   }
1821 
  // 4. Move the operations from the innermost loop's body into the outermost
  // loop's body, just before the second-outermost loop; then delete the
  // innermost loop's extra terminator and the second-outermost loop.
1824   AffineForOp secondOutermostLoop = loops[1];
1825   innermost.getBody()->back().erase();
1826   outermost.getBody()->getOperations().splice(
1827       Block::iterator(secondOutermostLoop.getOperation()),
1828       innermost.getBody()->getOperations());
1829   secondOutermostLoop.erase();
1830   return success();
1831 }
1832 
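// Maps `forOp` so that its iterations are distributed across a grid of
// processors. As a sketch, for processorId = {%bx, %tx} and
// numProcessors = {%nbx, %ntx} (hypothetical SSA values), the loop's bounds
// are rewritten as:
//
//   lb' = lb + (%bx * %ntx + %tx) * step
//   step' = step * (%nbx * %ntx)
//
// so that each processor starts at its own offset and strides over the
// iterations handled by the other processors.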
1833 void mlir::mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
1834                                  ArrayRef<Value> numProcessors) {
1835   assert(processorId.size() == numProcessors.size());
1836   if (processorId.empty())
1837     return;
1838 
1839   OpBuilder b(forOp);
1840   Location loc(forOp.getLoc());
1841   AffineExpr lhs, rhs;
1842   bindSymbols(forOp.getContext(), lhs, rhs);
1843   auto mulMap = AffineMap::get(0, 2, lhs * rhs);
1844   auto addMap = AffineMap::get(0, 2, lhs + rhs);
1845 
1846   Value linearIndex = processorId.front();
1847   for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
1848     auto mulApplyOp = b.create<AffineApplyOp>(
1849         loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
1850     linearIndex = b.create<AffineApplyOp>(
1851         loc, addMap, ValueRange{mulApplyOp, processorId[i]});
1852   }
1853 
1854   auto mulApplyOp = b.create<AffineApplyOp>(
1855       loc, mulMap, ValueRange{linearIndex, forOp.getStep()});
1856   Value lb = b.create<AffineApplyOp>(
1857       loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()});
1858   forOp.setLowerBound(lb);
1859 
1860   Value step = forOp.getStep();
1861   for (auto numProcs : numProcessors)
1862     step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
1863   forOp.setStep(step);
1864 }
1865 
1866 /// Given a memref region, determine the lowest depth at which transfers can be
1867 /// placed for it, and return the corresponding block, start and end positions
1868 /// in the block for placing incoming (read) and outgoing (write) copies
1869 /// respectively. The lowest depth depends on whether the region being accessed
1870 /// is hoistable with respect to one or more immediately surrounding loops.
1871 static void
1872 findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
1873                              Block::iterator &begin, Block::iterator &end,
1874                              Block **copyPlacementBlock,
1875                              Block::iterator *copyInPlacementStart,
1876                              Block::iterator *copyOutPlacementStart) {
1877   const auto *cst = region.getConstraints();
1878   SmallVector<Value, 4> symbols;
1879   cst->getValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
1880 
1881   SmallVector<AffineForOp, 4> enclosingFors;
1882   getLoopIVs(*block.begin(), &enclosingFors);
1883   // Walk up loop parents till we find an IV on which this region is
1884   // symbolic/variant.
1885   auto it = enclosingFors.rbegin();
1886   for (auto e = enclosingFors.rend(); it != e; ++it) {
1887     // TODO: also need to be checking this for regions symbols that
1888     // aren't loop IVs, whether we are within their resp. defs' dominance scope.
1889     if (llvm::is_contained(symbols, it->getInductionVar()))
1890       break;
1891   }
1892 
1893   if (it != enclosingFors.rbegin()) {
1894     auto lastInvariantIV = *std::prev(it);
1895     *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
1896     *copyOutPlacementStart = std::next(*copyInPlacementStart);
1897     *copyPlacementBlock = lastInvariantIV->getBlock();
1898   } else {
1899     *copyInPlacementStart = begin;
1900     *copyOutPlacementStart = end;
1901     *copyPlacementBlock = &block;
1902   }
1903 }
1904 
// Info comprising the stride and the number of elements transferred per
// stride.
1906 struct StrideInfo {
1907   int64_t stride;
1908   int64_t numEltPerStride;
1909 };
1910 
1911 /// Returns striding information for a copy/transfer of this region with
1912 /// potentially multiple striding levels from outermost to innermost. For an
1913 /// n-dimensional region, there can be at most n-1 levels of striding
1914 /// successively nested.
1915 //  TODO: make this work with non-identity layout maps.
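//  For example, copying a 64x32 region out of a memref<64x128xf32> with an
//  identity layout would report a single stride level: chunks of 32
//  contiguous elements whose starts are 128 elements apart, i.e.,
//  {stride = 128, numEltPerStride = 32}.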
1916 static void getMultiLevelStrides(const MemRefRegion &region,
1917                                  ArrayRef<int64_t> bufferShape,
1918                                  SmallVectorImpl<StrideInfo> *strideInfos) {
1919   if (bufferShape.size() <= 1)
1920     return;
1921 
1922   int64_t numEltPerStride = 1;
1923   int64_t stride = 1;
1924   for (int d = bufferShape.size() - 1; d >= 1; d--) {
1925     int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
1926     stride *= dimSize;
1927     numEltPerStride *= bufferShape[d];
1928     // A stride is needed only if the region has a shorter extent than the
1929     // memref along the dimension *and* has an extent greater than one along the
1930     // next major dimension.
1931     if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
1932       strideInfos->push_back({stride, numEltPerStride});
1933     }
1934   }
1935 }
1936 
1937 /// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
1938 /// returns the outermost AffineForOp of the copy loop nest. `lbMaps` and
1939 /// `ubMaps` along with `lbOperands` and `ubOperands` hold the lower and upper
1940 /// bound information for the copy loop nest. `fastBufOffsets` contain the
1941 /// expressions to be subtracted out from the respective copy loop iterators in
/// order to index the fast buffer. If `isCopyOut` is true, generates a
/// copy-out; otherwise a copy-in. Builder `b` should be set to the point the
/// copy nest is inserted.
///
/// As an example, the copy-in nest for a 2-d region is generated as:
1947 /// for x = ...
1948 ///   for y = ...
1949 ///     fast_buf[x - offset_x][y - offset_y] = memref[x][y]
1950 ///
1951 static AffineForOp
1952 generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
1953                       ArrayRef<AffineMap> lbMaps, ArrayRef<Value> lbOperands,
1954                       ArrayRef<AffineMap> ubMaps, ArrayRef<Value> ubOperands,
1955                       ArrayRef<AffineExpr> fastBufOffsets, bool isCopyOut,
1956                       OpBuilder b) {
1957   assert(llvm::all_of(lbMaps, [&](AffineMap lbMap) {
1958     return lbMap.getNumInputs() == lbOperands.size();
1959   }));
1960   assert(llvm::all_of(ubMaps, [&](AffineMap ubMap) {
1961     return ubMap.getNumInputs() == ubOperands.size();
1962   }));
1963 
1964   unsigned rank = memref.getType().cast<MemRefType>().getRank();
1965   assert(lbMaps.size() == rank && "wrong number of lb maps");
1966   assert(ubMaps.size() == rank && "wrong number of ub maps");
1967 
1968   SmallVector<Value, 4> memIndices;
1969   SmallVector<AffineExpr, 4> fastBufExprs;
1970   SmallVector<Value, 4> fastBufMapOperands;
1971   AffineForOp copyNestRoot;
1972   SmallVector<AffineApplyOp, 4> mayBeDeadApplys;
1973   for (unsigned d = 0; d < rank; ++d) {
1974     auto forOp = createCanonicalizedAffineForOp(b, loc, lbOperands, lbMaps[d],
1975                                                 ubOperands, ubMaps[d]);
1976     if (d == 0)
1977       copyNestRoot = forOp;
1978 
1979     b = OpBuilder::atBlockTerminator(forOp.getBody());
1980 
1981     auto fastBufOffsetMap =
1982         AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
1983     auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
1984 
1985     // Construct the subscript for the fast memref being copied into/from:
1986     // x - offset_x.
1987     fastBufExprs.push_back(b.getAffineDimExpr(2 * d + 1) -
1988                            b.getAffineDimExpr(2 * d));
1989     fastBufMapOperands.push_back(offset);
1990     fastBufMapOperands.push_back(forOp.getInductionVar());
1991     mayBeDeadApplys.push_back(offset);
1992 
1993     // Subscript for the slow memref being copied.
1994     memIndices.push_back(forOp.getInductionVar());
1995   }
1996 
1997   auto fastBufMap =
1998       AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext());
1999   fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands);
2000   fastBufMap = simplifyAffineMap(fastBufMap);
2001   canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands);
2002 
2003   // Drop any dead affine.applys.
2004   for (auto applyOp : mayBeDeadApplys)
2005     if (applyOp.use_empty())
2006       applyOp.erase();
2007 
2008   if (!isCopyOut) {
2009     // Copy in.
2010     auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
2011     b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap,
2012                             fastBufMapOperands);
2013     return copyNestRoot;
2014   }
2015 
2016   // Copy out.
2017   auto load =
2018       b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands);
2019   b.create<AffineStoreOp>(loc, load, memref, memIndices);
2020   return copyNestRoot;
2021 }
2022 
2023 static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
2024 emitRemarkForBlock(Block &block) {
2025   return block.getParentOp()->emitRemark();
2026 }
2027 
2028 /// Creates a buffer in the faster memory space for the specified memref region;
2029 /// generates a copy from the lower memory space to this one, and replaces all
2030 /// loads/stores in the block range [`begin', `end') of `block' to load/store
2031 /// from that buffer. Returns failure if copies could not be generated due to
2032 /// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
2033 /// in copyPlacementBlock specify the insertion points where the incoming copies
2034 /// and outgoing copies, respectively, should be inserted (the insertion happens
2035 /// right before the insertion point). Since `begin` can itself be invalidated
2036 /// due to the memref rewriting done from this method, the output argument
2037 /// `nBegin` is set to its replacement (set to `begin` if no invalidation
/// happens). Since outgoing copies could have been inserted at `end`, the
2039 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
2040 /// size of the fast buffer allocated.
2041 static LogicalResult generateCopy(
2042     const MemRefRegion &region, Block *block, Block::iterator begin,
2043     Block::iterator end, Block *copyPlacementBlock,
2044     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
2045     AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
2046     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
2047     Block::iterator *nBegin, Block::iterator *nEnd) {
2048   *nBegin = begin;
2049   *nEnd = end;
2050 
2051   FuncOp f = begin->getParentOfType<FuncOp>();
2052   OpBuilder topBuilder(f.getBody());
2053   Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
2054 
2055   if (begin == end)
2056     return success();
2057 
  // Whether the copy-out point is at the end of the block in which we are
  // doing explicit copying.
2060   bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
2061 
2062   // Copies for read regions are going to be inserted at 'begin'.
2063   OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
2064   // Copies for write regions are going to be inserted at 'end'.
2065   OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
2066   OpBuilder &b = region.isWrite() ? epilogue : prologue;
2067 
2068   // Builder to create constants at the top level.
2069   auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
2070   OpBuilder top(func.getBody());
2071 
2072   auto loc = region.loc;
2073   auto memref = region.memref;
2074   auto memRefType = memref.getType().cast<MemRefType>();
2075 
2076   if (!memRefType.getLayout().isIdentity()) {
2077     LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
2078     return failure();
2079   }
2080 
2081   // Indices to use for the copying.
2082   // Indices for the original memref being copied from/to.
2083   SmallVector<Value, 4> memIndices;
2084   // Indices for the faster buffer being copied into/from.
2085   SmallVector<Value, 4> bufIndices;
2086 
2087   unsigned rank = memRefType.getRank();
2088   SmallVector<int64_t, 4> fastBufferShape;
2089 
2090   // Compute the extents of the buffer.
2091   std::vector<SmallVector<int64_t, 4>> lbs;
2092   SmallVector<int64_t, 8> lbDivisors;
2093   lbs.reserve(rank);
2094   Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
2095       &fastBufferShape, &lbs, &lbDivisors);
2096   if (!numElements.hasValue()) {
2097     LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
2098     return failure();
2099   }
2100 
2101   if (numElements.getValue() == 0) {
2102     LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
2103     *sizeInBytes = 0;
2104     return success();
2105   }
2106 
2107   SmallVector<AffineMap, 4> lbMaps(rank), ubMaps(rank);
2108   for (unsigned i = 0; i < rank; ++i)
2109     region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]);
2110 
2111   const FlatAffineValueConstraints *cst = region.getConstraints();
2112   // 'regionSymbols' hold values that this memory region is symbolic/parametric
2113   // on; these typically include loop IVs surrounding the level at which the
2114   // copy generation is being done or other valid symbols in MLIR.
2115   SmallVector<Value, 8> regionSymbols;
2116   cst->getValues(rank, cst->getNumIds(), &regionSymbols);
2117 
2118   // Construct the index expressions for the fast memory buffer. The index
2119   // expression for a particular dimension of the fast buffer is obtained by
2120   // subtracting out the lower bound on the original memref's data region
2121   // along the corresponding dimension.
2122 
2123   // Index start offsets for faster memory buffer relative to the original.
2124   SmallVector<AffineExpr, 4> fastBufOffsets;
2125   fastBufOffsets.reserve(rank);
2126   for (unsigned d = 0; d < rank; d++) {
2127     assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
2128 
2129     AffineExpr offset = top.getAffineConstantExpr(0);
2130     for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++)
2131       offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
2132     assert(lbDivisors[d] > 0);
2133     offset =
2134         (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
2135 
2136     // Set copy start location for this dimension in the lower memory space
2137     // memref.
2138     if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
2139       auto indexVal = caf.getValue();
2140       if (indexVal == 0) {
2141         memIndices.push_back(zeroIndex);
2142       } else {
2143         memIndices.push_back(
2144             top.create<arith::ConstantIndexOp>(loc, indexVal).getResult());
2145       }
2146     } else {
2147       // The coordinate for the start location is just the lower bound along the
2148       // corresponding dimension on the memory region (stored in 'offset').
2149       auto map = AffineMap::get(
2150           cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
2151       memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
2152     }
2153     // The fast buffer is copied into at location zero; addressing is relative.
2154     bufIndices.push_back(zeroIndex);
2155 
2156     // Record the offsets since they are needed to remap the memory accesses of
2157     // the original memref further below.
2158     fastBufOffsets.push_back(offset);
2159   }
2160 
2161   // The faster memory space buffer.
2162   Value fastMemRef;
2163 
2164   // Check if a buffer was already created.
2165   bool existingBuf = fastBufferMap.count(memref) > 0;
2166   if (!existingBuf) {
2167     AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
2168     auto fastMemRefType =
2169         MemRefType::get(fastBufferShape, memRefType.getElementType(),
2170                         fastBufferLayout, copyOptions.fastMemorySpace);
2171 
2172     // Create the fast memory space buffer just before the 'affine.for'
2173     // operation.
2174     fastMemRef =
2175         prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult();
2176     // Record it.
2177     fastBufferMap[memref] = fastMemRef;
2178     // fastMemRefType is a constant shaped memref.
2179     *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
2180     LLVM_DEBUG(emitRemarkForBlock(*block)
2181                << "Creating fast buffer of type " << fastMemRefType
2182                << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
2183                << " KiB\n");
2184   } else {
2185     // Reuse the one already created.
2186     fastMemRef = fastBufferMap[memref];
2187     *sizeInBytes = 0;
2188   }
2189 
2190   auto numElementsSSA =
2191       top.create<arith::ConstantIndexOp>(loc, numElements.getValue());
2192 
2193   Value dmaStride = nullptr;
2194   Value numEltPerDmaStride = nullptr;
2195   if (copyOptions.generateDma) {
2196     SmallVector<StrideInfo, 4> dmaStrideInfos;
2197     getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
2198 
2199     // TODO: use all stride levels once DmaStartOp is extended for
2200     // multi-level strides.
2201     if (dmaStrideInfos.size() > 1) {
2202       LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
2203       return failure();
2204     }
2205 
2206     if (!dmaStrideInfos.empty()) {
2207       dmaStride =
2208           top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
2209       numEltPerDmaStride = top.create<arith::ConstantIndexOp>(
2210           loc, dmaStrideInfos[0].numEltPerStride);
2211     }
2212   }
2213 
  // Record the last operation where we want the memref replacement to end. We
  // later do the memref replacement only in [begin, postDomFilter] so that the
  // original memrefs used in the data movement code themselves don't get
  // replaced.
2218   auto postDomFilter = std::prev(end);
2219 
2220   // Create fully composed affine maps for each memref.
2221   auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
2222   fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
2223   auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
2224   fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
2225 
2226   if (!copyOptions.generateDma) {
2227     // Point-wise copy generation.
2228     auto copyNest =
2229         generatePointWiseCopy(loc, memref, fastMemRef, lbMaps,
2230                               /*lbOperands=*/regionSymbols, ubMaps,
2231                               /*ubOperands=*/regionSymbols, fastBufOffsets,
2232                               /*isCopyOut=*/region.isWrite(), b);
2233 
2234     // Record this so that we can skip it from yet another copy.
2235     copyNests.insert(copyNest);
2236 
2237     // Since new ops are being appended (for copy out's), adjust the end to
2238     // mark end of block range being processed if necessary.
2239     if (region.isWrite() && isCopyOutAtEndOfBlock)
2240       *nEnd = Block::iterator(copyNest.getOperation());
2241   } else {
2242     // DMA generation.
2243     // Create a tag (single element 1-d memref) for the DMA.
2244     auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
2245                                          copyOptions.tagMemorySpace);
2246     auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType);
2247 
2248     SmallVector<Value, 4> tagIndices({zeroIndex});
2249     auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
2250     fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
2251     if (!region.isWrite()) {
2252       // DMA non-blocking read from original buffer to fast buffer.
2253       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
2254                                  fastMemRef, bufAffineMap, bufIndices,
2255                                  tagMemRef, tagAffineMap, tagIndices,
2256                                  numElementsSSA, dmaStride, numEltPerDmaStride);
2257     } else {
2258       // DMA non-blocking write from fast buffer to the original memref.
2259       auto op = b.create<AffineDmaStartOp>(
2260           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
2261           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
2262           dmaStride, numEltPerDmaStride);
2263       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
2264       // end to mark end of block range being processed.
2265       if (isCopyOutAtEndOfBlock)
2266         *nEnd = Block::iterator(op.getOperation());
2267     }
2268 
2269     // Matching DMA wait to block on completion; tag always has a 0 index.
2270     b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
2271                               numElementsSSA);
2272 
2273     // Generate dealloc for the tag.
2274     auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef);
2275     if (*nEnd == end && isCopyOutAtEndOfBlock)
2276       // Since new ops are being appended (for outgoing DMAs), adjust the end to
2277       // mark end of range of the original.
2278       *nEnd = Block::iterator(tagDeallocOp.getOperation());
2279   }
2280 
2281   // Generate dealloc for the buffer.
2282   if (!existingBuf) {
2283     auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef);
2284     // When generating pointwise copies, `nEnd' has to be set to deallocOp on
2285     // the fast buffer (since it marks the new end insertion point).
2286     if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
2287       *nEnd = Block::iterator(bufDeallocOp.getOperation());
2288   }
2289 
2290   // Replace all uses of the old memref with the faster one while remapping
2291   // access indices (subtracting out lower bound offsets for each dimension).
2292   // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
2293   // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
2294   // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
2295   // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
2296   // d2, d3 correspond to the original indices (%i, %j).
2297   SmallVector<AffineExpr, 4> remapExprs;
2298   remapExprs.reserve(rank);
2299   for (unsigned i = 0; i < rank; i++) {
2300     // The starting operands of indexRemap will be regionSymbols (the symbols on
2301     // which the memref region is parametric); then those corresponding to
2302     // the memref's original indices follow.
2303     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
2304     remapExprs.push_back(dimExpr - fastBufOffsets[i]);
2305   }
2306   auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs,
2307                                    b.getContext());
2308 
2309   // Record the begin since it may be invalidated by memref replacement.
2310   Block::iterator prevOfBegin;
2311   bool isBeginAtStartOfBlock = (begin == block->begin());
2312   if (!isBeginAtStartOfBlock)
2313     prevOfBegin = std::prev(begin);
2314 
2315   // *Only* those uses within the range [begin, end) of 'block' are replaced.
2316   (void)replaceAllMemRefUsesWith(memref, fastMemRef,
2317                                  /*extraIndices=*/{}, indexRemap,
2318                                  /*extraOperands=*/regionSymbols,
2319                                  /*symbolOperands=*/{},
2320                                  /*domOpFilter=*/&*begin,
2321                                  /*postDomOpFilter=*/&*postDomFilter);
2322 
2323   *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
2324 
2325   return success();
2326 }
2327 
/// Constructs the memref region to include the entire memref. Returns false
/// for dynamically shaped memrefs for now. `numParamLoopIVs` is the number of
2330 /// enclosing loop IVs of `op` (starting from the outermost) that the region
2331 /// is parametric on.
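/// For instance, for a memref<64x128xf32>, the constraints constructed below
/// are simply 0 <= d0 <= 63 and 0 <= d1 <= 127, with the first
/// `numParamLoopIVs` enclosing loop IVs recorded as the region's parameters.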
2332 static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
2333                                   MemRefRegion *region) {
2334   unsigned rank;
2335   if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
2336     rank = loadOp.getMemRefType().getRank();
2337     region->memref = loadOp.getMemRef();
2338     region->setWrite(false);
2339   } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
2340     rank = storeOp.getMemRefType().getRank();
2341     region->memref = storeOp.getMemRef();
2342     region->setWrite(true);
2343   } else {
2344     assert(false && "expected load or store op");
2345     return false;
2346   }
2347   auto memRefType = region->memref.getType().cast<MemRefType>();
2348   if (!memRefType.hasStaticShape())
2349     return false;
2350 
2351   auto *regionCst = region->getConstraints();
2352 
2353   // Just get the first numSymbols IVs, which the memref region is parametric
2354   // on.
2355   SmallVector<AffineForOp, 4> ivs;
2356   getLoopIVs(*op, &ivs);
2357   ivs.resize(numParamLoopIVs);
2358   SmallVector<Value, 4> symbols;
2359   extractForInductionVars(ivs, &symbols);
2360   regionCst->reset(rank, numParamLoopIVs, 0);
2361   regionCst->setValues(rank, rank + numParamLoopIVs, symbols);
2362 
2363   // Memref dim sizes provide the bounds.
2364   for (unsigned d = 0; d < rank; d++) {
2365     auto dimSize = memRefType.getDimSize(d);
2366     assert(dimSize > 0 && "filtered dynamic shapes above");
2367     regionCst->addBound(FlatAffineConstraints::LB, d, 0);
2368     regionCst->addBound(FlatAffineConstraints::UB, d, dimSize - 1);
2369   }
2370   return true;
2371 }
2372 
2373 LogicalResult mlir::affineDataCopyGenerate(Block::iterator begin,
2374                                            Block::iterator end,
2375                                            const AffineCopyOptions &copyOptions,
2376                                            Optional<Value> filterMemRef,
2377                                            DenseSet<Operation *> &copyNests) {
2378   if (begin == end)
2379     return success();
2380 
2381   assert(begin->getBlock() == std::prev(end)->getBlock() &&
2382          "Inconsistent block begin/end args");
2383   assert(end != end->getBlock()->end() && "end can't be the block terminator");
2384 
2385   Block *block = begin->getBlock();
2386 
  // Copies will be generated for this depth, i.e., symbolic in all loops
  // surrounding this block range.
2389   unsigned copyDepth = getNestingDepth(&*begin);
2390 
2391   LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
2392                           << "\n");
2393   LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
2394   LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
2395 
  // List of memory regions to copy for. We need a map vector to have a
  // guaranteed iteration order to write test cases. CHECK-DAG doesn't help
  // here since the allocs, for example, are identical except for the SSA id.
2399   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
2400   SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
2401 
2402   // Map from original memref's to the fast buffers that their accesses are
2403   // replaced with.
2404   DenseMap<Value, Value> fastBufferMap;
2405 
2406   // To check for errors when walking the block.
2407   bool error = false;
2408 
  // Walk this range of operations to gather all memory regions.
2410   block->walk(begin, end, [&](Operation *opInst) {
2411     // Gather regions to allocate to buffers in faster memory space.
2412     if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
2413       if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) ||
2414           (loadOp.getMemRefType().getMemorySpaceAsInt() !=
2415            copyOptions.slowMemorySpace))
2416         return;
2417     } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
2418       if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) ||
2419           storeOp.getMemRefType().getMemorySpaceAsInt() !=
2420               copyOptions.slowMemorySpace)
2421         return;
2422     } else {
2423       // Neither load nor a store op.
2424       return;
2425     }
2426 
2427     // Compute the MemRefRegion accessed.
2428     auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
2429     if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr,
2430                                /*addMemRefDimBounds=*/false))) {
2431       LLVM_DEBUG(llvm::dbgs()
2432                  << "Error obtaining memory region: semi-affine maps?\n");
2433       LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
2434       if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
2435         LLVM_DEBUG(
2436             opInst->emitError("non-constant memref sizes not yet supported"));
2437         error = true;
2438         return;
2439       }
2440     }
2441 
2442     // Each memref has a single buffer associated with it irrespective of how
2443     // many load's and store's happen on it.
2444     // TODO: in the future, when regions don't intersect and satisfy
2445     // other properties (based on load/store regions), we could consider
2446     // multiple buffers per memref.
2447 
2448     // Add to the appropriate region if it's not already in it, or take a
2449     // bounding box union with the existing one if it's already in there.
2450     // Note that a memref may have both read and write regions - so update the
2451     // region in the other list if one exists (write in case of read and vice
2452     // versa) since there is a single bounding box for a memref across all reads
2453     // and writes that happen on it.
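    // As an illustration: if the read region of a memref covers rows [0, 63]
    // and its write region covers rows [32, 127], the single buffer is sized
    // by their bounding box union, i.e., rows [0, 127].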

    // Attempts to update; returns true if 'region' exists in targetRegions.
    auto updateRegion =
        [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
                &targetRegions) {
          const auto *const it = targetRegions.find(region->memref);
          if (it == targetRegions.end())
            return false;

          // Perform a union with the existing region.
          if (failed(it->second->unionBoundingBox(*region))) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Memory region bounding box failed; "
                          "over-approximating to the entire memref\n");
            // If the union fails, we will over-approximate.
            if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
              LLVM_DEBUG(opInst->emitError(
                  "non-constant memref sizes not yet supported"));
              error = true;
              return true;
            }
            it->second->getConstraints()->clearAndCopyFrom(
                *region->getConstraints());
          } else {
            // Union was computed and stored in 'it->second': copy to 'region'.
            region->getConstraints()->clearAndCopyFrom(
                *it->second->getConstraints());
          }
          return true;
        };

    bool existsInRead = updateRegion(readRegions);
    if (error)
      return;
    bool existsInWrite = updateRegion(writeRegions);
    if (error)
      return;

    // Finally add it to the region list.
    if (region->isWrite() && !existsInWrite) {
      writeRegions[region->memref] = std::move(region);
    } else if (!region->isWrite() && !existsInRead) {
      readRegions[region->memref] = std::move(region);
    }
  });

  if (error) {
    LLVM_DEBUG(begin->emitError(
        "copy generation failed for one or more memrefs in this block"));
    return failure();
  }

  uint64_t totalCopyBuffersSizeInBytes = 0;
  bool ret = true;
  auto processRegions =
      [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
              &regions) {
        for (const auto &regionEntry : regions) {
          // For each region, hoist copy in/out past all hoistable
          // 'affine.for's.
          Block::iterator copyInPlacementStart, copyOutPlacementStart;
          Block *copyPlacementBlock;
          findHighestBlockForPlacement(
              *regionEntry.second, *block, begin, end, &copyPlacementBlock,
              &copyInPlacementStart, &copyOutPlacementStart);

          uint64_t sizeInBytes;
          Block::iterator nBegin, nEnd;
          LogicalResult iRet = generateCopy(
              *regionEntry.second, block, begin, end, copyPlacementBlock,
              copyInPlacementStart, copyOutPlacementStart, copyOptions,
              fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
          if (succeeded(iRet)) {
            // begin/end could have been invalidated, and need to be updated.
            begin = nBegin;
            end = nEnd;
            totalCopyBuffersSizeInBytes += sizeInBytes;
          }
          ret = ret && succeeded(iRet);
        }
      };
  processRegions(readRegions);
  processRegions(writeRegions);

  if (!ret) {
    LLVM_DEBUG(begin->emitError(
        "copy generation failed for one or more memrefs in this block"));
    return failure();
  }

  // If the copied range comprises a single loop, emit a remark on it; for a
  // general range of operations, a note will be emitted at the caller.
  AffineForOp forOp;
  if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
    LLVM_DEBUG(forOp.emitRemark()
               << llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024)
               << " KiB of copy buffers in fast memory space for this block");
  }

  if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
    block->getParentOp()->emitWarning(
        "total size of all copy buffers for this block exceeds fast memory "
        "capacity");
  }

  return success();
}

// A convenience version of affineDataCopyGenerate for all ops in the body of
// an AffineForOp.
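// Illustrative usage (a sketch; the option values are hypothetical, and the
// field order follows AffineCopyOptions as declared in LoopUtils.h):
//
//   AffineCopyOptions options = {/*generateDma=*/false, /*slowMemorySpace=*/0,
//                                /*fastMemorySpace=*/1, /*tagMemorySpace=*/0,
//                                /*fastMemCapacityBytes=*/32 * 1024};
//   DenseSet<Operation *> copyNests;
//   (void)affineDataCopyGenerate(forOp, options, /*filterMemRef=*/llvm::None,
//                                copyNests);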
LogicalResult mlir::affineDataCopyGenerate(AffineForOp forOp,
                                           const AffineCopyOptions &copyOptions,
                                           Optional<Value> filterMemRef,
                                           DenseSet<Operation *> &copyNests) {
  return affineDataCopyGenerate(forOp.getBody()->begin(),
                                std::prev(forOp.getBody()->end()), copyOptions,
                                filterMemRef, copyNests);
}

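// Illustrative usage of generateCopyForMemRegion (a sketch; 'loadOp' and
// 'copyOptions' are hypothetical values from the surrounding context):
//
//   MemRefRegion region(loadOp.getLoc());
//   if (succeeded(region.compute(loadOp, /*loopDepth=*/0))) {
//     CopyGenerateResult res;
//     if (succeeded(generateCopyForMemRegion(region, loadOp, copyOptions,
//                                            res)))
//       llvm::dbgs() << "fast buffer: " << *res.alloc << '\n';
//   }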
LogicalResult mlir::generateCopyForMemRegion(
    const MemRefRegion &memrefRegion, Operation *analyzedOp,
    const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
  Block *block = analyzedOp->getBlock();
  auto begin = analyzedOp->getIterator();
  auto end = std::next(begin);
  DenseMap<Value, Value> fastBufferMap;
  DenseSet<Operation *> copyNests;

  auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
                          copyOptions, fastBufferMap, copyNests,
                          &result.sizeInBytes, &begin, &end);
  if (failed(err))
    return err;

  const auto &en = fastBufferMap.find(memrefRegion.memref);
  // In some cases (empty loops), no copy generation would have happened.
  if (en == fastBufferMap.end())
    return failure();
  result.alloc = en->second.getDefiningOp();
  assert(result.alloc && "fast buffer expected to be locally allocated");
  assert(copyNests.size() <= 1 && "at most one copy nest is expected");
  result.copyNest = copyNests.empty() ? nullptr : *copyNests.begin();
  return success();
}

/// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
static void
gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
                   std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
  // Add a new empty level to the output if it doesn't already exist.
  assert(currLoopDepth <= depthToLoops.size() && "Unexpected currLoopDepth");
  if (currLoopDepth == depthToLoops.size())
    depthToLoops.emplace_back();

  for (auto &op : *block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      depthToLoops[currLoopDepth].push_back(forOp);
      gatherLoopsInBlock(forOp.getBody(), currLoopDepth + 1, depthToLoops);
    }
  }
}

/// Gathers all AffineForOps in 'func' grouped by loop depth.
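/// For example (a minimal sketch):
///
///   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
///   gatherLoops(func, depthToLoops);
///   // depthToLoops[0] now holds the top-level loops of 'func',
///   // depthToLoops[1] the loops immediately nested within those, and so on.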
void mlir::gatherLoops(FuncOp func,
                       std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
  for (auto &block : func)
    gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);

  // Remove last loop level from output since it's empty.
  if (!depthToLoops.empty()) {
    assert(depthToLoops.back().empty() && "Last loop level is not empty?");
    depthToLoops.pop_back();
  }
}

// TODO: if necessary, this can be extended to also compose in any
// affine.applys, fold to constant if all result dimensions of the map are
// constant (canonicalizeMapAndOperands below already does this for single
// result bound maps), and use simplifyMap to perform algebraic simplification.
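// For instance (an illustrative case): if the composed upper bound map comes
// out as (d0)[s0] -> (d0 + 4, d0 + 4, s0), removeDuplicateExprs reduces it to
// (d0)[s0] -> (d0 + 4, s0) before the affine.for op is created.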
AffineForOp mlir::createCanonicalizedAffineForOp(
    OpBuilder b, Location loc, ValueRange lbOperands, AffineMap lbMap,
    ValueRange ubOperands, AffineMap ubMap, int64_t step) {
  SmallVector<Value, 4> lowerOperands(lbOperands);
  SmallVector<Value, 4> upperOperands(ubOperands);

  fullyComposeAffineMapAndOperands(&lbMap, &lowerOperands);
  canonicalizeMapAndOperands(&lbMap, &lowerOperands);
  lbMap = removeDuplicateExprs(lbMap);
  fullyComposeAffineMapAndOperands(&ubMap, &upperOperands);
  canonicalizeMapAndOperands(&ubMap, &upperOperands);
  ubMap = removeDuplicateExprs(ubMap);

  return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap,
                               step);
}

/// Creates an AffineIfOp that encodes the conditional to choose between
/// the constant trip count version and an unknown trip count version of this
/// nest of loops. This is used to separate partial and full tiles when
/// `loops` is the set of intra-tile loops. The affine.if op is inserted at
/// the builder insertion point of `b`.
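/// As an illustration (a sketch, not the exact generated set): for a point
/// loop
///   affine.for %i = #lb(%it) to min #ub(%it)[%N]
/// whose full tile has the constant trip count 32, the guard is an affine.if
/// whose integer set holds exactly when all 32 iterations fit within the
/// bounds, i.e., roughly when %it + 32 <= %N.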
static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops,
                                            OpBuilder b) {
  if (loops.empty())
    return nullptr;

  auto *context = loops[0].getContext();

  FlatAffineValueConstraints cst;
  SmallVector<Operation *, 8> ops;
  llvm::append_range(ops, loops);
  (void)getIndexSet(ops, &cst);

  // Remove constraints that are independent of these loop IVs.
  cst.removeIndependentConstraints(/*pos=*/0, /*num=*/loops.size());

  // Construct the constraint set representing the guard for full tiles. The
  // lower bound (resp. upper bound) corresponding to the full tile should be
  // larger (resp. smaller) than any other lower (resp. upper) bound.
  SmallVector<int64_t, 8> fullTileLb, fullTileUb;
  for (auto loop : loops) {
    (void)loop;
    // TODO: straightforward to generalize to a non-unit stride.
    assert(loop.getStep() == 1 && "point loop step expected to be one");
    // Mark everything as symbols to find a constant difference pair.
    cst.setDimSymbolSeparation(/*newSymbolCount=*/cst.getNumDimAndSymbolIds() -
                               1);
    unsigned fullTileLbPos, fullTileUbPos;
    if (!cst.getConstantBoundOnDimSize(0, /*lb=*/nullptr,
                                       /*boundFloorDivisor=*/nullptr,
                                       /*ub=*/nullptr, &fullTileLbPos,
                                       &fullTileUbPos)) {
      LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n");
      return nullptr;
    }

    SmallVector<unsigned, 4> lbIndices, ubIndices;
    cst.getLowerAndUpperBoundIndices(/*pos=*/0, &lbIndices, &ubIndices);

    auto fLb = cst.getInequality(fullTileLbPos);
    auto fUb = cst.getInequality(fullTileUbPos);
    fullTileLb.assign(fLb.begin(), fLb.end());
    fullTileUb.assign(fUb.begin(), fUb.end());

    // The full tile lower bound should be >= any other lower bound.
    for (auto lbIndex : lbIndices)
      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
        cst.atIneq(lbIndex, i) = fullTileLb[i] - cst.atIneq(lbIndex, i);

    // The full tile upper bound should be <= any other upper bound.
    for (auto ubIndex : ubIndices)
      for (unsigned i = 0, e = cst.getNumCols(); i < e; ++i)
        cst.atIneq(ubIndex, i) -= fullTileUb[i];

    cst.removeId(0);
  }

  // The previous step leads to all zeros for the full tile lb and ub position
  // itself; remove those and any other duplicates / trivial redundancies.
  cst.removeTrivialRedundancy();

  // Turn everything into dims conservatively since we earlier turned all
  // trailing ids past the point loop IVs into symbols. Some of these could be
  // outer loop IVs; we'll canonicalize anyway.
  cst.setDimSymbolSeparation(0);

  IntegerSet ifCondSet = cst.getAsIntegerSet(context);
  // ifCondSet can be null if cst was empty -- this can happen if all loops
  // in the nest have constant trip counts.
  if (!ifCondSet)
    return nullptr;

  SmallVector<Value, 4> setOperands;
  cst.getValues(0, cst.getNumDimAndSymbolIds(), &setOperands);
  canonicalizeSetAndOperands(&ifCondSet, &setOperands);
  return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands,
                              /*withElseRegion=*/true);
}

/// Creates the full tile loop nest (along with its body).
static LogicalResult
createFullTiles(MutableArrayRef<AffineForOp> inputNest,
                SmallVectorImpl<AffineForOp> &fullTileLoops, OpBuilder b) {
  fullTileLoops.reserve(inputNest.size());

  // For each loop in the original nest, identify a lower/upper bound pair such
  // that their difference is a constant.
  FlatAffineValueConstraints cst;
  for (auto loop : inputNest) {
    // TODO: straightforward to generalize to a non-unit stride.
    if (loop.getStep() != 1) {
      LLVM_DEBUG(llvm::dbgs()
                 << "[tile separation] non-unit stride not implemented\n");
      return failure();
    }
    SmallVector<Operation *, 1> loopOp{loop.getOperation()};
    (void)getIndexSet(loopOp, &cst);
    // Mark everything other than this loop IV as a symbol to get a
    // <lb, ub> pair with a constant difference.
    cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - 1);
    unsigned lbPos, ubPos;
    if (!cst.getConstantBoundOnDimSize(/*pos=*/0, /*lb=*/nullptr,
                                       /*lbDivisor=*/nullptr, /*ub=*/nullptr,
                                       &lbPos, &ubPos) ||
        lbPos == ubPos) {
      LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / "
                                 "equalities not yet handled\n");
      return failure();
    }

    // Set all identifiers as dimensions uniformly since some of those marked
    // as symbols above could be outer loop IVs (corresponding tile space IVs).
    cst.setDimSymbolSeparation(/*newSymbolCount=*/0);

    AffineValueMap lbVmap, ubVmap;
    cst.getIneqAsAffineValueMap(/*pos=*/0, lbPos, lbVmap, b.getContext());
    cst.getIneqAsAffineValueMap(/*pos=*/0, ubPos, ubVmap, b.getContext());
    AffineForOp fullTileLoop = createCanonicalizedAffineForOp(
        b, loop.getLoc(), lbVmap.getOperands(), lbVmap.getAffineMap(),
        ubVmap.getOperands(), ubVmap.getAffineMap());
    b = OpBuilder::atBlockTerminator(fullTileLoop.getBody());
    fullTileLoops.push_back(fullTileLoop);
  }

  // Add the body for the full tile loop nest.
  BlockAndValueMapping operandMap;
  for (const auto &loopEn : llvm::enumerate(inputNest))
    operandMap.map(loopEn.value().getInductionVar(),
                   fullTileLoops[loopEn.index()].getInductionVar());
  b = OpBuilder::atBlockTerminator(fullTileLoops.back().getBody());
  for (auto &op : inputNest.back().getBody()->without_terminator())
    b.clone(op, operandMap);
  return success();
}

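/// Separates the full tiles from the partial tiles in 'inputNest' by
/// versioning the nest under the separation condition. The result has the
/// following shape (an illustrative sketch):
///
///   affine.if #fullTileGuard(...) {
///     <full tile loop nest with constant trip counts>
///   } else {
///     <original loop nest, which also handles the partial tiles>
///   }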
LogicalResult
mlir::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
                        SmallVectorImpl<AffineForOp> *fullTileNest) {
  if (inputNest.empty())
    return success();

  auto firstLoop = inputNest[0];

  // Each successive for op has to be nested in the previous one.
  auto prevLoop = firstLoop;
  for (auto loop : inputNest.drop_front(1)) {
    assert(loop->getParentOp() == prevLoop && "input not contiguously nested");
    prevLoop = loop;
  }

  // Create the full tile loop nest.
  SmallVector<AffineForOp, 4> fullTileLoops;
  OpBuilder b(firstLoop);
  if (failed(createFullTiles(inputNest, fullTileLoops, b))) {
    if (!fullTileLoops.empty())
      fullTileLoops.front().erase();
    return failure();
  }

  // Create and insert the version select right before the root of the nest.
  b = OpBuilder(firstLoop);
  AffineIfOp ifOp = createSeparationCondition(inputNest, b);
  if (!ifOp) {
    fullTileLoops.front().erase();
    LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failed to create "
                               "the separation condition\n");
    return failure();
  }

  // Move the full tile into the then block.
  Block *thenBlock = ifOp.getThenBlock();
  AffineForOp outermostFullTileLoop = fullTileLoops[0];
  thenBlock->getOperations().splice(
      std::prev(thenBlock->end()),
      outermostFullTileLoop->getBlock()->getOperations(),
      Block::iterator(outermostFullTileLoop));

  // Move the partial tile into the else block. The partial tile is the same
  // as the original loop nest.
  Block *elseBlock = ifOp.getElseBlock();
  elseBlock->getOperations().splice(std::prev(elseBlock->end()),
                                    firstLoop->getBlock()->getOperations(),
                                    Block::iterator(firstLoop));

  if (fullTileNest)
    *fullTileNest = std::move(fullTileLoops);

  return success();
}