1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Dialect/Vector/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "mlir/Transforms/LoopUtils.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Debug.h"
28 
29 using llvm::dbgs;
30 
31 #define DEBUG_TYPE "linalg-hoisting"
32 
33 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
39   bool changed = true;
40   while (changed) {
41     changed = false;
42     func.walk([&changed](Operation *op) {
43       if (!isa<AllocOp, AllocaOp, DeallocOp>(op))
44         return;
45 
46       LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
47       auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
48       LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
49 
50       // Only hoist out of immediately enclosing scf::ForOp.
51       if (!loop)
52         return;
53 
54       // If any operand is defined inside the loop don't hoist.
55       if (llvm::any_of(op->getOperands(), [&](Value v) {
56             return !loop.isDefinedOutsideOfLoop(v);
57           }))
58         return;
59 
60       LLVM_DEBUG(DBGS() << "All operands defined outside \n");
61 
62       // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
63       Value v;
64       if (op->getNumResults() > 0) {
65         assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
66         v = op->getResult(0);
67       }
68       if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
69             return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner());
70           })) {
71         LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
72         return;
73       }
74 
75       // Move AllocOp before the loop.
76       if (isa<AllocOp, AllocaOp>(op))
77         (void)loop.moveOutOfLoop({op});
78       else // Move DeallocOp outside of the loop.
79         op->moveAfter(loop);
80       changed = true;
81     });
82   }
83 }
84 
namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
/// A default-constructed instance (both members null) is used in this file to
/// signal that no hoistable write was found.
struct HoistableWrite {
  // The vector.transfer_write to hoist; null when no candidate exists.
  vector::TransferWriteOp transferWriteOp;
  // Optional subtensor_insert whose source is produced by `transferWriteOp`;
  // null when the transfer_write result is yielded directly.
  SubTensorInsertOp subTensorInsertOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
/// A default-constructed instance (both members null) is used in this file to
/// signal that no matching read was found.
struct HoistableRead {
  // The vector.transfer_read to hoist; null when no candidate exists.
  vector::TransferReadOp transferReadOp;
  // Optional subtensor feeding `transferReadOp`; set only when the matching
  // write goes through a subtensor_insert.
  SubTensorOp subTensorOp;
};
} // namespace
99 
100 /// Return true if op1 and op2 are the same constant or the same SSA value.
101 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
102   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
103     Attribute attr = ofr.dyn_cast<Attribute>();
104     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
105     if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
106       attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
107     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
108       return intAttr.getValue().getSExtValue();
109     return llvm::None;
110   };
111   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
112   if (cst1 && cst2 && *cst1 == *cst2)
113     return true;
114   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
115   return v1 && v2 && v1 == v2;
116 }
117 
118 /// Return true is all offsets, sizes and strides are equal.
119 static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
120   if (s.static_offsets().size() != si.static_offsets().size())
121     return false;
122   if (s.static_sizes().size() != si.static_sizes().size())
123     return false;
124   if (s.static_strides().size() != si.static_strides().size())
125     return false;
126   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
127     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
128       return false;
129   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
130     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
131       return false;
132   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
133     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
134       return false;
135   return true;
136 }
137 
138 /// Look for a HoistableRead, in the given tensor uses, accessing the same
139 /// offset as the HoistableWrite.
140 static HoistableRead findMatchingTransferRead(HoistableWrite write,
141                                               Value srcTensor) {
142   assert(write.transferWriteOp &&
143          "expected hoistable write to have a .transfer_write");
144 
145   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
146                     << *write.transferWriteOp.getOperation() << "\n");
147   if (write.subTensorInsertOp)
148     LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
149                       << *write.subTensorInsertOp.getOperation() << "\n");
150 
151   for (Operation *user : srcTensor.getUsers()) {
152     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
153                       << "\n");
154 
155     // If HoistableWrite involves a SubTensorInsertOp, we need to find a
156     // matching SubTensorOp.
157     SubTensorOp subTensorOp;
158     Operation *maybeTransferReadUser = user;
159     if (write.subTensorInsertOp) {
160       subTensorOp = dyn_cast<SubTensorOp>(user);
161       if (!subTensorOp || subTensorOp.getResult().getType() !=
162                               write.subTensorInsertOp.source().getType())
163         continue;
164 
165       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
166                         << *subTensorOp << " vs " << *write.subTensorInsertOp
167                         << "\n");
168       if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
169         continue;
170 
171       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
172       // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
173       //   1. the transfer_write we want to hoist.
174       //   2. a matching transfer_read.
175       // Anything else, we skip.
176       bool skip = false;
177       Operation *otherUser = nullptr;
178       for (Operation *u : subTensorOp->getUsers()) {
179         if (u == write.transferWriteOp)
180           continue;
181         if (otherUser) {
182           skip = true;
183           break;
184         }
185         otherUser = u;
186       }
187       if (skip || !otherUser)
188         continue;
189       maybeTransferReadUser = otherUser;
190     }
191 
192     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
193                       << "\n");
194     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
195     if (read && read.indices() == write.transferWriteOp.indices() &&
196         read.getVectorType() == write.transferWriteOp.getVectorType())
197       return HoistableRead{read, subTensorOp};
198   }
199   return HoistableRead();
200 }
201 
202 /// Check if the chunk of data inserted by the HoistableWrite are read by any
203 /// other op than the HoistableRead candidate.
204 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
205                                            HoistableRead candidateRead,
206                                            BlockArgument tensorArg) {
207   // Make sure none of the other uses read the part of the tensor modified
208   // by the transfer_write.
209   llvm::SmallVector<Value::use_range, 1> uses;
210   uses.push_back(tensorArg.getUses());
211   while (!uses.empty()) {
212     for (OpOperand &use : uses.pop_back_val()) {
213       Operation *user = use.getOwner();
214       // Skip the candidate use, only inspect the "other" uses.
215       if (user == candidateRead.transferReadOp ||
216           user == candidateRead.subTensorOp || user == write.transferWriteOp ||
217           user == write.subTensorInsertOp)
218         continue;
219       // Consider all transitive uses through a subtensor / subtensor_insert.
220       // TODO: atm we just bail because a stronger analysis is needed for these
221       // cases.
222       if (isa<SubTensorOp, SubTensorInsertOp>(user))
223         return true;
224       // Consider all transitive uses through a vector.transfer_write.
225       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
226         uses.push_back(writeUser->getResult(0).getUses());
227         continue;
228       }
229       // Consider all nested uses through an scf::ForOp. We may have
230       // pass-through tensor arguments left from previous level of
231       // hoisting.
232       if (auto forUser = dyn_cast<scf::ForOp>(user)) {
233         Value arg = forUser.getLoopBody().getArgument(
234             use.getOperandNumber() - forUser.getNumControlOperands() +
235             /*iv value*/ 1);
236         uses.push_back(arg.getUses());
237         continue;
238       }
239       // Follow the use yield as long as it doesn't escape the original
240       // region.
241       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
242       if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
243                            yieldUser->getParentOp())) {
244         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
245         uses.push_back(ret.getUses());
246         continue;
247       }
248       auto read = dyn_cast<vector::TransferReadOp>(user);
249       if (!read || !isDisjointTransferIndices(
250                        cast<VectorTransferOpInterface>(read.getOperation()),
251                        cast<VectorTransferOpInterface>(
252                            write.transferWriteOp.getOperation()))) {
253         return true;
254       }
255     }
256   }
257   return false;
258 }
259 
260 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
261 /// Return the null HoistableWrite() if it is not comprised of a
262 /// vector.transfer_write + optional subtensor_insert or if any of the indexings
263 /// is `forOp`-dependent.
264 static HoistableWrite
265 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
266                                         OpOperand &yieldOperand) {
267   Value v = yieldOperand.get();
268   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
269     // Indexing must not depend on `forOp`.
270     for (Value operand : write.indices())
271       if (!forOp.isDefinedOutsideOfLoop(operand))
272         return HoistableWrite();
273 
274     return HoistableWrite{write, nullptr};
275   }
276 
277   if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
278     // Inserted subTensor must come from vector.transfer_write.
279     auto write =
280         subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
281     if (!write)
282       return HoistableWrite();
283 
284     // Tensor inserted into must be a BBArg at position matching yieldOperand's.
285     auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
286     if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
287         bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
288       return HoistableWrite();
289 
290     // Indexing inserted into must not depend on `forOp`.
291     for (Value operand : subTensorInsertOp->getOperands().drop_front(
292              SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
293       if (!forOp.isDefinedOutsideOfLoop(operand))
294         return HoistableWrite();
295 
296     return HoistableWrite{write, subTensorInsertOp};
297   }
298 
299   return HoistableWrite();
300 }
301 
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// The read (and its optional subtensor) is moved before the scf.for that owns
/// `tensorBBArg`, the write (and its optional subtensor_insert) is moved after
/// it, and the loop is re-created through `cloneWithNewYields`, passing the
/// read vector and the written vector. The original loop is not erased here;
/// that is left to the caller.
/// No legality check is performed: callers must have established that the
/// pair matches and that hoisting is safe.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.subTensorOp && write.subTensorInsertOp) ||
          (!read.subTensorOp && !write.subTensorInsertOp)) &&
         "expected matching subtensor / subtensor_insert");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read subtensor is present, hoist it.
  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
    llvm_unreachable("Unexpected failure moving subtensor out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // Region arguments are [iv, iter_args...], so the init arg / yield / result
  // position is the region argument number minus the number of ivs.
  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: the hoisted read now lives outside the loop and
  // must read from the loop's init operand instead of the region argument.
  if (read.subTensorOp)
    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.subTensorInsertOp)
    write.subTensorInsertOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the write no longer lives in the loop, so yield the
  // tensor it used to write into instead of the write's result.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.subTensorInsertOp)
    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a subtensor_insert is present or not, it carries the
  // update on the tensor operands.
  if (write.subTensorInsertOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
  } else {
    // Replace the users of the loop result first, then make the write consume
    // the loop result; the reverse order would rewrite the write's own source.
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  // The last result of the new loop is used as the vector to write.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}
371 
// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// transfer_read or transfer_write. For transfer_write uses recurse to make
// sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// removed by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Iterate to a fixed point: each successful hoist erases a loop, which
  // forces the walk to be restarted from scratch.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      // Inspect each iteration argument together with the yield operand at the
      // same position.
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        // Step 1: need a transfer_write whose result has exactly one use.
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.subTensorInsertOp)
          LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: "
                            << *write.subTensorInsertOp.getOperation() << "\n");
        // The write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        // hoistReadWrite created a replacement loop; drop the original one.
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      OwningRewritePatternList patterns;
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
434 
435 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
436   bool changed = true;
437   while (changed) {
438     changed = false;
439 
440     func.walk([&](vector::TransferReadOp transferRead) {
441       if (!transferRead.getShapedType().isa<MemRefType>())
442         return WalkResult::advance();
443 
444       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
445                         << *transferRead.getOperation() << "\n");
446       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
447       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
448                         << "\n");
449       if (!loop)
450         return WalkResult::advance();
451 
452       if (failed(moveLoopInvariantCode(
453               cast<LoopLikeOpInterface>(loop.getOperation()))))
454         llvm_unreachable(
455             "Unexpected failure to move invariant code out of loop");
456 
457       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
458                         << "\n");
459 
460       llvm::SetVector<Operation *> forwardSlice;
461       getForwardSlice(transferRead.getOperation(), &forwardSlice);
462 
463       // Look for the last TransferWriteOp in the forwardSlice of
464       // `transferRead` that operates on the same memref.
465       vector::TransferWriteOp transferWrite;
466       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
467         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
468         if (!candidateWrite || candidateWrite.source() != transferRead.source())
469           continue;
470         transferWrite = candidateWrite;
471       }
472 
473       // All operands of the TransferRead must be defined outside of the loop.
474       for (auto operand : transferRead.getOperands())
475         if (!loop.isDefinedOutsideOfLoop(operand))
476           return WalkResult::advance();
477 
478       // Only hoist transfer_read / transfer_write pairs for now.
479       if (!transferWrite)
480         return WalkResult::advance();
481 
482       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
483                         << "\n");
484 
485       // Approximate aliasing by checking that:
486       //   1. indices are the same,
487       //   2. no other operations in the loop access the same memref except
488       //      for transfer_read/transfer_write accessing statically disjoint
489       //      slices.
490       if (transferRead.indices() != transferWrite.indices() &&
491           transferRead.getVectorType() == transferWrite.getVectorType())
492         return WalkResult::advance();
493 
494       // TODO: may want to memoize this information for performance but it
495       // likely gets invalidated often.
496       DominanceInfo dom(loop);
497       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
498         return WalkResult::advance();
499       for (auto &use : transferRead.source().getUses()) {
500         if (!dom.properlyDominates(loop, use.getOwner()))
501           continue;
502         if (use.getOwner() == transferRead.getOperation() ||
503             use.getOwner() == transferWrite.getOperation())
504           continue;
505         if (auto transferWriteUse =
506                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
507           if (!isDisjointTransferSet(
508                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
509                   cast<VectorTransferOpInterface>(
510                       transferWriteUse.getOperation())))
511             return WalkResult::advance();
512         } else if (auto transferReadUse =
513                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
514           if (!isDisjointTransferSet(
515                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
516                   cast<VectorTransferOpInterface>(
517                       transferReadUse.getOperation())))
518             return WalkResult::advance();
519         } else {
520           // Unknown use, we cannot prove that it doesn't alias with the
521           // transferRead/transferWrite operations.
522           return WalkResult::advance();
523         }
524       }
525 
526       // Hoist read before.
527       if (failed(loop.moveOutOfLoop({transferRead})))
528         llvm_unreachable(
529             "Unexpected failure to move transfer read out of loop");
530 
531       // Hoist write after.
532       transferWrite->moveAfter(loop);
533 
534       // Rewrite `loop` with new yields by cloning and erase the original loop.
535       OpBuilder b(transferRead);
536       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
537                                          transferWrite.vector());
538 
539       // Transfer write has been hoisted, need to update the written value to
540       // the value yielded by the newForOp.
541       transferWrite.vector().replaceAllUsesWith(
542           newForOp.getResults().take_back()[0]);
543 
544       changed = true;
545       loop.erase();
546       // Need to interrupt and restart because erasing the loop messes up the
547       // walk.
548       return WalkResult::interrupt();
549     });
550   }
551 }
552 
553 /// Ensure prerequisites that guarantee pad op hoisting can occur.
554 /// Return failure in the cases when we cannot perform hoisting; i.e. if either:
555 ///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
556 ///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
557 ///   3. There exists an op with a region that is dominated by
558 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
559 ///    LinalgOp.
560 ///   3. There exists an op with side effects that is dominated by
561 ///    `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
562 ///
563 /// While ensuring prerequisites:
564 ///   1. Fill the `backwardSlice` to contain the topologically sorted ops
565 ///   dominated by `outermostEnclosingForOp`.
566 ///   2. Fill the `packingLoops` to contain only the enclosing loops of
567 ///   `backwardSlice` whose IV is actually used in computing padding. Loops that
568 ///   remain in `backwardSlice` but that are not in `packingLoops` are
569 ///   dimensions of reuse.
570 static LogicalResult
571 hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
572                                    llvm::SetVector<Operation *> &backwardSlice,
573                                    llvm::SetVector<Operation *> &packingLoops) {
574   // Bail on any use that isn't an input of a Linalg op.
575   // Hoisting of inplace updates happens after vectorization.
576   for (OpOperand &use : padTensorOp.result().getUses()) {
577     auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
578     if (!linalgUser || !linalgUser.isInputTensor(&use))
579       return failure();
580   }
581 
582   // Get at most nLevels of enclosing loops.
583   SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
584   Operation *outermostEnclosingForOp = nullptr,
585             *nextEnclosingForOp =
586                 padTensorOp->getParentOfType<LoopLikeOpInterface>();
587   while (nLevels-- > 0 && nextEnclosingForOp) {
588     outermostEnclosingForOp = nextEnclosingForOp;
589     reverseEnclosingLoops.push_back(outermostEnclosingForOp);
590     nextEnclosingForOp =
591         nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
592   }
593   if (!outermostEnclosingForOp)
594     return failure();
595 
596   // Get the backwards slice from `padTensorOp` that is dominated by the
597   // outermost enclosing loop.
598   DominanceInfo domInfo(outermostEnclosingForOp);
599   getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
600                    [&](Operation *op) {
601                      return domInfo.dominates(outermostEnclosingForOp, op);
602                    });
603 
604   // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
605   if (llvm::any_of(backwardSlice, [](Operation *op) {
606         return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
607                !isa<LinalgOp>(op);
608       }))
609     return failure();
610 
611   // Filter out the loops whose induction variable is not used to compute the
612   // padded result. As a first approximation, just look for IVs that have no use
613   // in the backwardSlice.
614   // These are the dimensions of reuse that we can exploit to reduce the amount
615   // of work / memory.
616   // TODO: would this optimization compose better as a canonicalization?
617   for (LoopLikeOpInterface loop : reverseEnclosingLoops) {
618     auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
619     if (!forOp)
620       continue;
621     for (Operation *user : forOp.getInductionVar().getUsers()) {
622       if (backwardSlice.contains(user)) {
623         packingLoops.insert(forOp);
624         break;
625       }
626     }
627   }
628 
629   // Backward slice is a topologically sorted list of ops starting at
630   // `outermostEnclosingForOp`.
631   assert(outermostEnclosingForOp == backwardSlice.front());
632 
633   return success();
634 }
635 
636 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
637 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) {
638   MLIRContext *ctx = forOp->getContext();
639   AffineExpr lb, ub, step;
640   bindDims(ctx, lb, ub);
641   bindSymbols(ctx, step);
642   return b.create<AffineApplyOp>(
643       forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx),
644       ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
645 }
646 
647 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
648 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) {
649   MLIRContext *ctx = forOp->getContext();
650   AffineExpr iv, lb, step;
651   bindDims(ctx, iv, lb);
652   bindSymbols(ctx, step);
653   return b.create<AffineApplyOp>(
654       forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx),
655       ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()});
656 }
657 
658 LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
659                                                   unsigned nLoops) {
660   llvm::SetVector<Operation *> backwardSlice, packingLoops;
661   if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
662                                                 backwardSlice, packingLoops)))
663     return failure();
664 
665   // Update actual number of loops, which may be smaller.
666   nLoops = packingLoops.size();
667 
668   Location loc = padTensorOp->getLoc();
669   RankedTensorType paddedTensorType = padTensorOp.getResultType();
670   unsigned paddedRank = paddedTensorType.getRank();
671 
672   // Backward slice is a topologically sorted list of ops starting at
673   // `outermostEnclosingForOp`.
674   Operation *outermostEnclosingForOp = backwardSlice.front();
675   // IP just before the outermost loop considered that we hoist above.
676   OpBuilder b(outermostEnclosingForOp);
677 
678   // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
679   // padding.
680   SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
681   // TODO: go grab dims when necessary, for now PadTensorOp returns a static
682   // tensor.
683   llvm::append_range(packedShape, paddedTensorType.getShape());
684   auto packedTensorType =
685       RankedTensorType::get(packedShape, paddedTensorType.getElementType());
686   auto dynamicSizes =
687       llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
688         return buildLoopTripCount(b, cast<scf::ForOp>(op));
689       }));
690   Value packedTensor = b.create<linalg::InitTensorOp>(
691       loc, dynamicSizes, packedTensorType.getShape(),
692       packedTensorType.getElementType());
693 
694   // Clone the operations involved in the backward slice, iteratively stepping
695   // into the loops that we encounter.
696   // The implementation proceeds in a stack-like fashion:
697   //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
698   //      deeper in the stack.
699   //   2. Create a SubTensorInsert at the top of the stack.
700   //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
701   //     the cloned loops.
702   SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
703   clonedLoopIvs.reserve(nLoops);
704   leadingPackedTensorIndexings.reserve(nLoops);
705   BlockAndValueMapping bvm;
706   // Stack step 1. iteratively clone loops and push `packedTensor`.
707   // Insert `padTensorOp` into the backwardSlice so we clone it too.
708   backwardSlice.insert(padTensorOp);
709   for (Operation *op : backwardSlice) {
710     if (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op)) {
711       b.clone(*op, bvm);
712       continue;
713     }
714     // TODO: support more cases as they appear.
715     auto forOp = dyn_cast<scf::ForOp>(op);
716     assert(forOp && "Expected scf::ForOp when hoisting pad ops");
717     // Unused loop, just skip it.
718     if (!packingLoops.contains(forOp))
719       continue;
720     auto clonedForOp =
721         b.create<scf::ForOp>(loc, forOp.lowerBound(), forOp.upperBound(),
722                              forOp.step(), packedTensor);
723     assert(clonedForOp->getNumRegions() == 1);
724     clonedLoopIvs.push_back(clonedForOp.getInductionVar());
725     b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
726     leadingPackedTensorIndexings.push_back(
727         buildLoopIterationCount(b, clonedForOp));
728     bvm.map(forOp.getInductionVar(), clonedLoopIvs.back());
729     packedTensor = clonedForOp.getRegionIterArgs().front();
730   }
731 
732   // Stack step 2. create SubTensorInsertOp at the top of the stack.
733   // offsets = [clonedLoopIvs, 0 .. 0].
734   SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
735                                     leadingPackedTensorIndexings.end());
736   offsets.append(paddedRank, b.getIndexAttr(0));
737   // sizes = [1 .. 1, paddedShape].
738   SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
739   for (int64_t sz : paddedTensorType.getShape()) {
740     // TODO: go grab dims when necessary, for now PadTensorOp returns a static
741     // tensor.
742     assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
743     sizes.push_back(b.getIndexAttr(sz));
744   }
745   // strides = [1 .. 1].
746   SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));
747 
748   Value inserted =
749       b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
750                                   packedTensor, offsets, sizes, strides);
751 
752   // Stack step 3. iteratively pop the stack and propagate the yield.
753   Value valueToYield = inserted;
754   for (Value iv : llvm::reverse(clonedLoopIvs)) {
755     auto forOp = scf::getForInductionVarOwner(iv);
756     b.setInsertionPointToEnd(&forOp.getRegion().front());
757     b.create<scf::YieldOp>(loc, valueToYield);
758     valueToYield = forOp.getResult(0);
759   }
760 
761   // Now the packed tensor is ready, replace the original padding op by a
762   // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
763   b.setInsertionPoint(padTensorOp);
764   SmallVector<Value> loopIterationCounts =
765       llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
766         return buildLoopIterationCount(b, cast<scf::ForOp>(loop));
767       }));
768   // offsets = [originalLoopIvs, 0 .. 0].
769   offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
770   offsets.append(paddedRank, b.getIndexAttr(0));
771   // sizes = [1 .. 1, paddedShape] (definedabove).
772   // strides = [1 .. 1] (defined above)
773   packedTensor =
774       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
775   padTensorOp.replaceAllUsesWith(
776       b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
777                             offsets, sizes, strides)
778           ->getResult(0));
779 
780   Operation *toErase = padTensorOp;
781 
782   // Make the newly cloned `padTensorOp` available to the caller.
783   padTensorOp =
784       cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
785 
786   toErase->erase();
787 
788   return success();
789 }
790