1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Dialect/Vector/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "mlir/Transforms/LoopUtils.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 using llvm::dbgs;
30 
31 #define DEBUG_TYPE "linalg-hoisting"
32 
33 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  // The transfer_write to hoist. A null op denotes "no hoistable write found".
  vector::TransferWriteOp transferWriteOp;
  // Optional subtensor_insert that re-inserts the written subtensor into the
  // destination tensor; null when the write operates on the whole tensor.
  SubTensorInsertOp subTensorInsertOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  // The transfer_read to hoist. A null op denotes "no matching read found".
  vector::TransferReadOp transferReadOp;
  // Optional subtensor feeding the read; must mirror the write's
  // subtensor_insert (same offsets/sizes/strides) when present.
  SubTensorOp subTensorOp;
};
} // namespace
52 
53 /// Return true if op1 and op2 are the same constant or the same SSA value.
54 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
55   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
56     Attribute attr = ofr.dyn_cast<Attribute>();
57     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
58     if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
59       attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
60     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
61       return intAttr.getValue().getSExtValue();
62     return llvm::None;
63   };
64   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
65   if (cst1 && cst2 && *cst1 == *cst2)
66     return true;
67   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
68   return v1 && v2 && v1 == v2;
69 }
70 
71 /// Return true is all offsets, sizes and strides are equal.
72 static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
73   if (s.static_offsets().size() != si.static_offsets().size())
74     return false;
75   if (s.static_sizes().size() != si.static_sizes().size())
76     return false;
77   if (s.static_strides().size() != si.static_strides().size())
78     return false;
79   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
80     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
81       return false;
82   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
83     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
84       return false;
85   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
86     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
87       return false;
88   return true;
89 }
90 
/// Look for a HoistableRead, in the given tensor uses, accessing the same
/// offset as the HoistableWrite.
/// Returns a null HoistableRead (null transferReadOp) when no matching read
/// is found among the users of `srcTensor`.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.subTensorInsertOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
                      << *write.subTensorInsertOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If HoistableWrite involves a SubTensorInsertOp, we need to find a
    // matching SubTensorOp.
    SubTensorOp subTensorOp;
    Operation *maybeTransferReadUser = user;
    if (write.subTensorInsertOp) {
      // The user must be a subtensor whose result type matches the inserted
      // subtensor's type.
      subTensorOp = dyn_cast<SubTensorOp>(user);
      if (!subTensorOp || subTensorOp.getResult().getType() !=
                              write.subTensorInsertOp.source().getType())
        continue;

      // The subtensor must extract exactly the chunk that the
      // subtensor_insert writes back.
      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *subTensorOp << " vs " << *write.subTensorInsertOp
                        << "\n");
      if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : subTensorOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          // More than one non-write user: not hoistable.
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    // The candidate read must use the exact same indices and vector type as
    // the write for the pair to be hoistable.
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, subTensorOp};
  }
  return HoistableRead();
}
154 
/// Check if the chunk of data inserted by the HoistableWrite are read by any
/// other op than the HoistableRead candidate.
/// Returns true (i.e. "bail") when some use, transitively reachable from
/// `tensorArg`, cannot be proven disjoint from the written chunk.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use ranges; starts with the direct uses of `tensorArg` and
  // grows as the traversal steps through tensor-propagating ops.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
          user == write.subTensorInsertOp)
        continue;
      // Consider all transitive uses through a subtensor / subtensor_insert.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<SubTensorOp, SubTensorInsertOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the iter operand to the corresponding region argument of the
        // loop body (control operands and the induction variable come first).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        // Continue the traversal with the corresponding enclosing-op result.
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // Any remaining user must be a transfer_read on a statically disjoint
      // slice; anything else is conservatively treated as an unknown access.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}
212 
/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional subtensor_insert or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value is produced directly by a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value is a subtensor_insert of a transfer_write into
  // the loop's own iter_arg tensor.
  if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
    // Inserted subTensor must come from vector.transfer_write.
    auto write =
        subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : subTensorInsertOp->getOperands().drop_front(
             SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, subTensorInsertOp};
  }

  // Neither recognized pattern: not hoistable.
  return HoistableWrite();
}
254 
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// Moves the read (and optional subtensor) before `forOp`, moves the write
/// (and optional subtensor_insert) after it, and rewires the loop to carry
/// the vector value through a new iter_arg/yield pair instead of the tensor.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.subTensorOp && write.subTensorInsertOp) ||
          (!read.subTensorOp && !write.subTensorInsertOp)) &&
         "expected matching subtensor / subtensor_insert");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read subtensor is present, hoist it.
  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
    llvm_unreachable("Unexpected failure moving subtensor out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor.
  // The hoisted read now lives before the loop, so it must read from the
  // loop's init operand rather than the (loop-internal) region iter_arg.
  if (read.subTensorOp)
    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.subTensorInsertOp)
    write.subTensorInsertOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield.
  // The loop now yields the unmodified tensor; the write happens after the
  // loop instead of on every iteration.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.subTensorInsertOp)
    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields.
  // The read's vector becomes a new iter_arg and the written vector becomes a
  // new yielded value of the cloned loop.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a subtensor_insert is present or not, it carries the
  // update on the tensor operands.
  if (write.subTensorInsertOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}
324 
// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// tensor_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// remove by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Iterate to a fixed point: each successful hoist erases the loop and may
  // enable further hoisting after canonicalization.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      // Inspect each (iter_arg, yield operand) pair of the loop.
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.subTensorInsertOp)
          LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: "
                            << *write.subTensorInsertOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
387 
388 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
389   bool changed = true;
390   while (changed) {
391     changed = false;
392 
393     func.walk([&](vector::TransferReadOp transferRead) {
394       if (!transferRead.getShapedType().isa<MemRefType>())
395         return WalkResult::advance();
396 
397       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
398                         << *transferRead.getOperation() << "\n");
399       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
400       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
401                         << "\n");
402       if (!loop)
403         return WalkResult::advance();
404 
405       if (failed(moveLoopInvariantCode(
406               cast<LoopLikeOpInterface>(loop.getOperation()))))
407         llvm_unreachable(
408             "Unexpected failure to move invariant code out of loop");
409 
410       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
411                         << "\n");
412 
413       llvm::SetVector<Operation *> forwardSlice;
414       getForwardSlice(transferRead.getOperation(), &forwardSlice);
415 
416       // Look for the last TransferWriteOp in the forwardSlice of
417       // `transferRead` that operates on the same memref.
418       vector::TransferWriteOp transferWrite;
419       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
420         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
421         if (!candidateWrite || candidateWrite.source() != transferRead.source())
422           continue;
423         transferWrite = candidateWrite;
424       }
425 
426       // All operands of the TransferRead must be defined outside of the loop.
427       for (auto operand : transferRead.getOperands())
428         if (!loop.isDefinedOutsideOfLoop(operand))
429           return WalkResult::advance();
430 
431       // Only hoist transfer_read / transfer_write pairs for now.
432       if (!transferWrite)
433         return WalkResult::advance();
434 
435       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
436                         << "\n");
437 
438       // Approximate aliasing by checking that:
439       //   1. indices are the same,
440       //   2. no other operations in the loop access the same memref except
441       //      for transfer_read/transfer_write accessing statically disjoint
442       //      slices.
443       if (transferRead.indices() != transferWrite.indices() &&
444           transferRead.getVectorType() == transferWrite.getVectorType())
445         return WalkResult::advance();
446 
447       // TODO: may want to memoize this information for performance but it
448       // likely gets invalidated often.
449       DominanceInfo dom(loop);
450       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
451         return WalkResult::advance();
452       for (auto &use : transferRead.source().getUses()) {
453         if (!dom.properlyDominates(loop, use.getOwner()))
454           continue;
455         if (use.getOwner() == transferRead.getOperation() ||
456             use.getOwner() == transferWrite.getOperation())
457           continue;
458         if (auto transferWriteUse =
459                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
460           if (!isDisjointTransferSet(
461                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
462                   cast<VectorTransferOpInterface>(
463                       transferWriteUse.getOperation())))
464             return WalkResult::advance();
465         } else if (auto transferReadUse =
466                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
467           if (!isDisjointTransferSet(
468                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
469                   cast<VectorTransferOpInterface>(
470                       transferReadUse.getOperation())))
471             return WalkResult::advance();
472         } else {
473           // Unknown use, we cannot prove that it doesn't alias with the
474           // transferRead/transferWrite operations.
475           return WalkResult::advance();
476         }
477       }
478 
479       // Hoist read before.
480       if (failed(loop.moveOutOfLoop({transferRead})))
481         llvm_unreachable(
482             "Unexpected failure to move transfer read out of loop");
483 
484       // Hoist write after.
485       transferWrite->moveAfter(loop);
486 
487       // Rewrite `loop` with new yields by cloning and erase the original loop.
488       OpBuilder b(transferRead);
489       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
490                                          transferWrite.vector());
491 
492       // Transfer write has been hoisted, need to update the written value to
493       // the value yielded by the newForOp.
494       transferWrite.vector().replaceAllUsesWith(
495           newForOp.getResults().take_back()[0]);
496 
497       changed = true;
498       loop.erase();
499       // Need to interrupt and restart because erasing the loop messes up the
500       // walk.
501       return WalkResult::interrupt();
502     });
503   }
504 }
505 
/// Ensure prerequisites that guarantee pad op hoisting can occur.
/// Return failure in the cases when we cannot perform hoisting; i.e. if either:
///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
///   3. There exists an op with a region that is dominated by
///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
///    LinalgOp.
///   4. There exists an op with side effects that is dominated by
///    `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///    NOTE(review): condition 4. is not visibly enforced in the body below —
///    confirm whether the side-effect check was intended elsewhere.
///
/// While ensuring prerequisites:
///   1. Fill the `backwardSlice` to contain the topologically sorted ops
///   dominated by `outermostEnclosingForOp`.
///   2. Fill the `packingLoops` to contain only the enclosing loops of
///   `backwardSlice` whose IV is actually used in computing padding. Loops that
///   remain in `backwardSlice` but that are not in `packingLoops` are
///   dimensions of reuse.
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   llvm::SetVector<Operation *> &backwardSlice,
                                   llvm::SetVector<Operation *> &packingLoops) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of enclosing loops.
  // Walk outwards from `padTensorOp`, remembering loops from innermost to
  // outermost in `reverseEnclosingLoops`.
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  // No enclosing loop at all: nothing to hoist above.
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  return success();
}
588 
589 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
590 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) {
591   MLIRContext *ctx = forOp->getContext();
592   AffineExpr lb, ub, step;
593   bindDims(ctx, lb, ub);
594   bindSymbols(ctx, step);
595   return b.create<AffineApplyOp>(
596       forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx),
597       ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
598 }
599 
600 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
601 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) {
602   MLIRContext *ctx = forOp->getContext();
603   AffineExpr iv, lb, step;
604   bindDims(ctx, iv, lb);
605   bindSymbols(ctx, step);
606   return b.create<AffineApplyOp>(
607       forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx),
608       ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()});
609 }
610 
/// Hoist `padTensorOp` above (up to) `nLoops` enclosing packing loops by
/// amortizing the padding into a packed tensor built once above the loops.
/// On success, `padTensorOp` is updated in place to point at the newly cloned
/// pad op so the caller can keep transforming it.
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  llvm::SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding.
  // One leading dynamic dimension per packing loop, sized by its trip count.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  auto dynamicSizes =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
        return buildLoopTripCount(b, cast<scf::ForOp>(op));
      }));
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a SubTensorInsert at the top of the stack.
  //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
  //     the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  BlockAndValueMapping bvm;
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  for (Operation *op : backwardSlice) {
    // Region-less ops (and the pad itself) are cloned verbatim through `bvm`.
    if (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op)) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;
    // Clone the loop with the packed tensor threaded through as an iter_arg.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, forOp.lowerBound(), forOp.upperBound(),
                             forOp.step(), packedTensor);
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    leadingPackedTensorIndexings.push_back(
        buildLoopIterationCount(b, clonedForOp));
    bvm.map(forOp.getInductionVar(), clonedLoopIvs.back());
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create SubTensorInsertOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  Value inserted =
      b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
                                  packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(b, cast<scf::ForOp>(loop));
      }));
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (definedabove).
  // strides = [1 .. 1] (defined above)
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
                            offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the original op before `padTensorOp` is repointed below.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}
743