1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/SCF/Utils.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Dialect/Vector/VectorUtils.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/LoopUtils.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Debug.h"
30 
31 using llvm::dbgs;
32 
33 #define DEBUG_TYPE "linalg-hoisting"
34 
35 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
36 
37 using namespace mlir;
38 using namespace mlir::linalg;
39 
namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
/// A default-constructed HoistableWrite (null `transferWriteOp`) is used as a
/// "not hoistable / not found" sentinel by the helpers below.
struct HoistableWrite {
  // The vector.transfer_write to hoist; null in the sentinel state.
  vector::TransferWriteOp transferWriteOp;
  // Optional subtensor_insert fed by the transfer_write result; null when the
  // write directly produces the yielded tensor.
  SubTensorInsertOp subTensorInsertOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
/// A default-constructed HoistableRead (null `transferReadOp`) is used as a
/// "no matching read" sentinel.
struct HoistableRead {
  // The vector.transfer_read to hoist; null in the sentinel state.
  vector::TransferReadOp transferReadOp;
  // Optional subtensor feeding the transfer_read; set iff the matching write
  // goes through a subtensor_insert (see findMatchingTransferRead).
  SubTensorOp subTensorOp;
};
} // namespace
54 
55 /// Return true if op1 and op2 are the same constant or the same SSA value.
56 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
57   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
58     Attribute attr = ofr.dyn_cast<Attribute>();
59     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
60     if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
61       attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
62     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
63       return intAttr.getValue().getSExtValue();
64     return llvm::None;
65   };
66   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
67   if (cst1 && cst2 && *cst1 == *cst2)
68     return true;
69   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
70   return v1 && v2 && v1 == v2;
71 }
72 
73 /// Return true is all offsets, sizes and strides are equal.
74 static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
75   if (s.static_offsets().size() != si.static_offsets().size())
76     return false;
77   if (s.static_sizes().size() != si.static_sizes().size())
78     return false;
79   if (s.static_strides().size() != si.static_strides().size())
80     return false;
81   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
82     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
83       return false;
84   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
85     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
86       return false;
87   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
88     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
89       return false;
90   return true;
91 }
92 
93 /// Look for a HoistableRead, in the given tensor uses, accessing the same
94 /// offset as the HoistableWrite.
95 static HoistableRead findMatchingTransferRead(HoistableWrite write,
96                                               Value srcTensor) {
97   assert(write.transferWriteOp &&
98          "expected hoistable write to have a .transfer_write");
99 
100   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
101                     << *write.transferWriteOp.getOperation() << "\n");
102   if (write.subTensorInsertOp)
103     LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
104                       << *write.subTensorInsertOp.getOperation() << "\n");
105 
106   for (Operation *user : srcTensor.getUsers()) {
107     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
108                       << "\n");
109 
110     // If HoistableWrite involves a SubTensorInsertOp, we need to find a
111     // matching SubTensorOp.
112     SubTensorOp subTensorOp;
113     Operation *maybeTransferReadUser = user;
114     if (write.subTensorInsertOp) {
115       subTensorOp = dyn_cast<SubTensorOp>(user);
116       if (!subTensorOp || subTensorOp.getResult().getType() !=
117                               write.subTensorInsertOp.source().getType())
118         continue;
119 
120       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
121                         << *subTensorOp << " vs " << *write.subTensorInsertOp
122                         << "\n");
123       if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
124         continue;
125 
126       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
127       // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
128       //   1. the transfer_write we want to hoist.
129       //   2. a matching transfer_read.
130       // Anything else, we skip.
131       bool skip = false;
132       Operation *otherUser = nullptr;
133       for (Operation *u : subTensorOp->getUsers()) {
134         if (u == write.transferWriteOp)
135           continue;
136         if (otherUser) {
137           skip = true;
138           break;
139         }
140         otherUser = u;
141       }
142       if (skip || !otherUser)
143         continue;
144       maybeTransferReadUser = otherUser;
145     }
146 
147     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
148                       << "\n");
149     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
150     if (read && read.indices() == write.transferWriteOp.indices() &&
151         read.getVectorType() == write.transferWriteOp.getVectorType())
152       return HoistableRead{read, subTensorOp};
153   }
154   return HoistableRead();
155 }
156 
/// Check if the chunk of data inserted by the HoistableWrite `write` is read
/// by any op other than the HoistableRead `candidateRead`, following uses of
/// the loop iter_arg `tensorArg` transitively.
/// Return true when such an access may exist (i.e. hoisting is not provably
/// safe); return false when every other access is a provably disjoint
/// transfer_read.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use-ranges still to inspect, seeded with the uses of the
  // tensor iter_arg itself.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
          user == write.subTensorInsertOp)
        continue;
      // Consider all transitive uses through a subtensor / subtensor_insert.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<SubTensorOp, SubTensorInsertOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the corresponding region iter_arg: skip the
        // loop-control operands and account for the induction variable.
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        // Continue the traversal on the corresponding result of the op the
        // yield terminates.
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // A transfer_read is acceptable only when its indices are statically
      // disjoint from the hoistable write; any other op is conservatively
      // treated as an unknown access.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}
214 
215 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
216 /// Return the null HoistableWrite() if it is not comprised of a
217 /// vector.transfer_write + optional subtensor_insert or if any of the indexings
218 /// is `forOp`-dependent.
219 static HoistableWrite
220 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
221                                         OpOperand &yieldOperand) {
222   Value v = yieldOperand.get();
223   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
224     // Indexing must not depend on `forOp`.
225     for (Value operand : write.indices())
226       if (!forOp.isDefinedOutsideOfLoop(operand))
227         return HoistableWrite();
228 
229     return HoistableWrite{write, nullptr};
230   }
231 
232   if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
233     // Inserted subTensor must come from vector.transfer_write.
234     auto write =
235         subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
236     if (!write)
237       return HoistableWrite();
238 
239     // Tensor inserted into must be a BBArg at position matching yieldOperand's.
240     auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
241     if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
242         bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
243       return HoistableWrite();
244 
245     // Indexing inserted into must not depend on `forOp`.
246     for (Value operand : subTensorInsertOp->getOperands().drop_front(
247              SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
248       if (!forOp.isDefinedOutsideOfLoop(operand))
249         return HoistableWrite();
250 
251     return HoistableWrite{write, subTensorInsertOp};
252   }
253 
254   return HoistableWrite();
255 }
256 
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair:
/// move the read (and optional subtensor) before `forOp`, move the write (and
/// optional subtensor_insert) after it, and rewrite the loop to carry the
/// vector between them via new yields.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  // Either both sides go through a subtensor / subtensor_insert pair, or
  // neither does.
  assert(((read.subTensorOp && write.subTensorInsertOp) ||
          (!read.subTensorOp && !write.subTensorInsertOp)) &&
         "expected matching subtensor / subtensor_insert");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read subtensor is present, hoist it.
  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
    llvm_unreachable("Unexpected failure moving subtensor out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  // Position of the corresponding init operand / loop result.
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: now that the read sits outside the loop it must
  // read from the loop's init operand instead of the iter_arg.
  if (read.subTensorOp)
    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.subTensorInsertOp)
    write.subTensorInsertOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the loop now yields the tensor the (hoisted) write used
  // to consume.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.subTensorInsertOp)
    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields carrying the vector between the
  // hoisted read and write.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a subtensor_insert is present or not, it carries the
  // update on the tensor operands.
  if (write.subTensorInsertOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}
326 
// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// tensor_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can
// be removed by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Iterate to a fixed point: every successful hoist erases the loop and
  // canonicalizes, which may expose further hoisting opportunities.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        // Yield operand i corresponds to region iter_arg i.
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.subTensorInsertOp)
          LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: "
                            << *write.subTensorInsertOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
389 
390 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
391   bool changed = true;
392   while (changed) {
393     changed = false;
394 
395     func.walk([&](vector::TransferReadOp transferRead) {
396       if (!transferRead.getShapedType().isa<MemRefType>())
397         return WalkResult::advance();
398 
399       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
400                         << *transferRead.getOperation() << "\n");
401       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
402       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
403                         << "\n");
404       if (!loop)
405         return WalkResult::advance();
406 
407       if (failed(moveLoopInvariantCode(
408               cast<LoopLikeOpInterface>(loop.getOperation()))))
409         llvm_unreachable(
410             "Unexpected failure to move invariant code out of loop");
411 
412       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
413                         << "\n");
414 
415       SetVector<Operation *> forwardSlice;
416       getForwardSlice(transferRead.getOperation(), &forwardSlice);
417 
418       // Look for the last TransferWriteOp in the forwardSlice of
419       // `transferRead` that operates on the same memref.
420       vector::TransferWriteOp transferWrite;
421       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
422         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
423         if (!candidateWrite || candidateWrite.source() != transferRead.source())
424           continue;
425         transferWrite = candidateWrite;
426       }
427 
428       // All operands of the TransferRead must be defined outside of the loop.
429       for (auto operand : transferRead.getOperands())
430         if (!loop.isDefinedOutsideOfLoop(operand))
431           return WalkResult::advance();
432 
433       // Only hoist transfer_read / transfer_write pairs for now.
434       if (!transferWrite)
435         return WalkResult::advance();
436 
437       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
438                         << "\n");
439 
440       // Approximate aliasing by checking that:
441       //   1. indices are the same,
442       //   2. no other operations in the loop access the same memref except
443       //      for transfer_read/transfer_write accessing statically disjoint
444       //      slices.
445       if (transferRead.indices() != transferWrite.indices() &&
446           transferRead.getVectorType() == transferWrite.getVectorType())
447         return WalkResult::advance();
448 
449       // TODO: may want to memoize this information for performance but it
450       // likely gets invalidated often.
451       DominanceInfo dom(loop);
452       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
453         return WalkResult::advance();
454       for (auto &use : transferRead.source().getUses()) {
455         if (!dom.properlyDominates(loop, use.getOwner()))
456           continue;
457         if (use.getOwner() == transferRead.getOperation() ||
458             use.getOwner() == transferWrite.getOperation())
459           continue;
460         if (auto transferWriteUse =
461                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
462           if (!isDisjointTransferSet(
463                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
464                   cast<VectorTransferOpInterface>(
465                       transferWriteUse.getOperation())))
466             return WalkResult::advance();
467         } else if (auto transferReadUse =
468                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
469           if (!isDisjointTransferSet(
470                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
471                   cast<VectorTransferOpInterface>(
472                       transferReadUse.getOperation())))
473             return WalkResult::advance();
474         } else {
475           // Unknown use, we cannot prove that it doesn't alias with the
476           // transferRead/transferWrite operations.
477           return WalkResult::advance();
478         }
479       }
480 
481       // Hoist read before.
482       if (failed(loop.moveOutOfLoop({transferRead})))
483         llvm_unreachable(
484             "Unexpected failure to move transfer read out of loop");
485 
486       // Hoist write after.
487       transferWrite->moveAfter(loop);
488 
489       // Rewrite `loop` with new yields by cloning and erase the original loop.
490       OpBuilder b(transferRead);
491       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
492                                          transferWrite.vector());
493 
494       // Transfer write has been hoisted, need to update the written value to
495       // the value yielded by the newForOp.
496       transferWrite.vector().replaceAllUsesWith(
497           newForOp.getResults().take_back()[0]);
498 
499       changed = true;
500       loop.erase();
501       // Need to interrupt and restart because erasing the loop messes up the
502       // walk.
503       return WalkResult::interrupt();
504     });
505   }
506 }
507 
508 /// Return success if `v` is a value that is only transitively defined by ops of
509 /// type in `OpTypeList`.
510 template <typename... OpTypeList>
511 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) {
512   // Compute a backward slice up to, but not including, `outerLimit`.
513   SetVector<Operation *> backwardSlice;
514   getBackwardSlice(v, &backwardSlice, [&](Operation *op) {
515     return outerLimit->isProperAncestor(op);
516   });
517   // Traverse the backward slice and ensure we can perform the computation to
518   // hoist.
519   for (Operation *op : backwardSlice) {
520     if (isa<OpTypeList...>(op))
521       continue;
522     LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n");
523     return false;
524   }
525   return true;
526 }
527 
528 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
529   return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
530 }
531 
532 /// Compute the tightest lower bound with quantities that are all defined
533 /// outside of `outer`.
534 /// Return null if such a bound cannot be computed.
535 Value computeLoopIndependentLowerBound(OpBuilder &b, scf::ForOp outer,
536                                        Value v) {
537   if (isDefinedOutsideOrConstant(outer, v))
538     return v;
539   return Value();
540 }
541 
/// Compute the tightest upper bound with quantities that are all defined
/// outside of `outer`, by cloning/rebuilding the computation of `v` above
/// `outer` and substituting loop-dependent quantities away.
/// Expects all ops in the backward slice of `v` up to `outer` to be either
/// scf.for, affine.min or affine.apply.
static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer,
                                              Value v) {
  // Fast path: already loop-independent.
  if (isDefinedOutsideOrConstant(outer, v))
    return v;

  LLVM_DEBUG(DBGS() << "Begin loopIndependentUpperBound for: " << v << "\n");

  bool ok =
      backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp, AffineApplyOp>(
          outer, v);
  assert(ok && "expected to only be defined by scf::ForOp and AffineMinOp");
  (void)ok;

  // Compute a backward slice up to, but not including, `outer`.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(v, &backwardSlice,
                   [&](Operation *op) { return outer->isProperAncestor(op); });
  // NOTE(review): assumes `v` has a defining op (is not a BlockArgument);
  // otherwise a null pointer is inserted and the loop below would misbehave —
  // confirm callers guarantee this.
  backwardSlice.insert(v.getDefiningOp());

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(outer);
  Value res = v;
  // Maps ops/values in the slice to their clones rebuilt above `outer`.
  BlockAndValueMapping bvm;
  for (Operation *op : backwardSlice) {
    // Skip scf.for ops themselves; their IVs are handled through the min/max
    // substitution below.
    if (isa<scf::ForOp>(op))
      continue;
    if (isa<AffineApplyOp>(op)) {
      // Clone affine.apply above `outer`, remapping operands through `bvm`.
      b.clone(*op, bvm);
      continue;
    }
    auto sliceMinOp = cast<AffineMinOp>(op);
    // Callback used by substituteMin to derive min/max expressions for values
    // defined inside `outer` (e.g. scf.for IVs).
    GetMinMaxExprFn getSCFMinMax = [&](Value value,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *op) {
        return outer->isAncestor(op);
      });
    };
    // Perform the substitution of the operands of AffineMinOp.
    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
    SmallVector<Value> resultOperands = mapAndOperands.dims;
    llvm::append_range(resultOperands, mapAndOperands.symbols);
    AffineMap map = mapAndOperands.map;
    canonicalizeMapAndOperands(&map, &resultOperands);
    map = simplifyAffineMap(map);
    // Rebuild the affine.min above `outer`, remapping operands to any clones
    // produced earlier in this walk.
    res = b.create<AffineMinOp>(
        outer->getLoc(), map,
        llvm::to_vector<4>(llvm::map_range(resultOperands, [&](Value operand) {
          return bvm.lookupOrDefault(operand);
        })));
    bvm.map(sliceMinOp, res);
  }
  LLVM_DEBUG(DBGS() << "End loopIndependentUpperBound with: " << res << "\n");
  return res;
}
601 
602 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
603 /// The returned Value is guaranteed not to depend on any loop comprised in
604 /// [`outer`, `forOp`].
605 /// Return null if such a loop-independent quantity cannot be computed.
606 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp outer,
607                                 scf::ForOp forOp) {
608   MLIRContext *ctx = forOp->getContext();
609   AffineExpr lb, ub, step;
610   bindDims(ctx, lb, ub);
611   bindSymbols(ctx, step);
612   Value lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
613         ubVal = computeLoopIndependentUpperBound(b, outer, forOp.upperBound()),
614         stepVal = forOp.step();
615   if (!lbVal || !ubVal || !stepVal)
616     return Value();
617   auto loc = forOp->getLoc();
618   Value res = b.create<AffineApplyOp>(loc, (ub - lb).ceilDiv(step),
619                                       ValueRange{lbVal, ubVal, stepVal});
620   return res;
621 }
622 
623 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
624 /// The returned Value is guaranteed not to depend on any loop comprised in
625 /// [`outer`, `forOp`].
626 /// Return null if such a loop-independent quantity cannot be computed.
627 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
628                                      scf::ForOp forOp) {
629   MLIRContext *ctx = forOp->getContext();
630   AffineExpr iv, lb, step;
631   bindDims(ctx, iv, lb);
632   bindSymbols(ctx, step);
633   Value ivVal = forOp.getInductionVar(),
634         lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
635         stepVal = forOp.step();
636   if (!ivVal || !lbVal || !stepVal)
637     return Value();
638   auto loc = forOp->getLoc();
639   return b.create<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
640                                  ValueRange{ivVal, lbVal, stepVal});
641 }
642 
643 /// Ensure prerequisites that guarantee pad op hoisting can occur.
644 /// Return failure in the cases when we cannot perform hoisting; i.e. if either:
645 ///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
646 ///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
647 ///   3. There exists an op with a region that is dominated by
648 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
649 ///    LinalgOp.
650 ///   4. There exists an op with side effects that is dominated by
651 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///   5. The lower bound, upper bound and step of all the loops involved in the
///   hoisting cannot be computed from quantities that are invariant under
///   `outermostEnclosingForOp` (see `isDefinedOutsideOrConstant` and
///   `backwardsSliceOnlyHasOpsOfType`).
654 ///
655 /// While ensuring prerequisites:
656 ///   1. Fill the `backwardSlice` to contain the topologically sorted ops
657 ///   dominated by `outermostEnclosingForOp`.
658 ///   2. Fill the `packingLoops` to contain only the enclosing loops of
659 ///   `backwardSlice` whose IV is actually used in computing padding. Loops that
660 ///   remain in `backwardSlice` but that are not in `packingLoops` are
661 ///   dimensions of reuse.
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   SetVector<Operation *> &backwardSlice,
                                   SetVector<Operation *> &packingLoops,
                                   SmallVector<Value> &dynamicTensorSizes) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of immediately enclosing loops, innermost first.
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  // No enclosing loop at all: nothing to hoist above.
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  // The cloning logic in hoistPaddingOnTensors only handles those cases.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    // Only scf::ForOp loops are candidates for packing; other loop-like ops
    // are skipped and remain pure dimensions of reuse.
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  // Bail if any packing loop has a lower bound, upper bound or step that is
  // not hoistable above `outer`: lb and step must be defined outside `outer`
  // or be constants; ub may alternatively be computed from a backward slice
  // containing only scf.for, affine.min and affine.apply ops.
  scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
  if (llvm::any_of(packingLoops, [&](Operation *op) {
        scf::ForOp forOp = cast<scf::ForOp>(op);
        Value lb = forOp.lowerBound(), ub = forOp.upperBound(),
              step = forOp.step();
        return !isDefinedOutsideOrConstant(outer, lb) ||
               !(isDefinedOutsideOrConstant(outer, ub) ||
                 backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp,
                                                AffineApplyOp>(outer, ub)) ||
               !isDefinedOutsideOrConstant(outer, step);
      }))
    return failure();

  // IP just before the outermost loop considered that we hoist above.
  // The trip count of each packing loop becomes one dynamic dimension of the
  // packed tensor the caller materializes.
  OpBuilder b(outermostEnclosingForOp);
  dynamicTensorSizes =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
        return buildLoopTripCount(b, cast<scf::ForOp>(outermostEnclosingForOp),
                                  cast<scf::ForOp>(op));
      }));
  // Assert all loop trip counts can be computed.
  if (!llvm::all_of(dynamicTensorSizes, [](Value v) { return v; }))
    llvm_unreachable("loop independence prerequisite not met");
  return success();
}
751 
/// Hoist the padding of `padTensorOp` above up to `nLoops` enclosing loops by
/// packing the padded results into a single larger tensor that is created
/// before the outermost hoisted loop. On success, `padTensorOp` is updated
/// in place to refer to the newly cloned pad op so the caller can keep
/// transforming it; the original op is erased.
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  // Check the transformation applies and compute the packing loops together
  // with the dynamic leading sizes of the packed tensor.
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding: one leading dynamic dimension per packing loop, sized by that
  // loop's trip count, followed by the (static) padded shape.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a SubTensorInsert at the top of the stack.
  //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  // `bvm` maps original values to their clones so cloned ops pick up cloned
  // operands.
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : backwardSlice) {
    // Specifically sit out in the subtensor(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto subTensor = dyn_cast<SubTensorOp>(op))
      if (bvm.lookupOrDefault(subTensor.source()) == packedTensor)
        continue;
    // Side-effect-free ops that are region-less (or a PadTensorOp, whose
    // region is cloned wholesale) are cloned directly.
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;

    // Clone the loop, threading `packedTensor` through as an iter_arg so the
    // packed result can be yielded back out of the cloned loop nest.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Step into the cloned loop body and continue cloning there.
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount = buildLoopIterationCount(
        b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    // Push: inside this loop, the packed tensor is the loop's iter_arg.
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create SubTensorInsertOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  // Insert the cloned padded result into its slot of the packed tensor.
  Value inserted =
      b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
                                  packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(
            b, cast<scf::ForOp>(outermostEnclosingForOp),
            cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
                            offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the (now use-free) original op; it can only be erased
  // after `padTensorOp` has been repointed to the clone below.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}
905