//===- HoistPadding.cpp - Hoisting transformation for PadTensorOp ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

/// Analysis class to support PadTensorOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. There is no immediately enclosing scf::ForOp.
///   3. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with a region.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      is empty.
/// Other cases succeed and will trigger hoisting of the pad op.
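/// See the illustrative before/after sketch preceding `hoistPaddingOnTensors`
/// below for the rewrite this analysis enables.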
struct HoistingAnalysis {
  HoistingAnalysis(PadTensorOp padTensorOp, int nLevels);

  bool isValid() { return valid; }

  /// Footprint of the packedTensor, computed from the packingLoops and
  /// `backwardSlice`.
  FailureOr<SmallVector<Value>> getPackedTensorSizes(ImplicitLocOpBuilder &b);

  /// The padTensorOp that needs to be hoisted.
  PadTensorOp padTensorOp;

  /// The maximum number of immediately enclosing scf::ForOp to hoist over.
  int nLevels;

  /// The outermost loop, determined by `nLevels`, above which `padTensorOp`
  /// will be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padTensorOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps nested under `outermostEnclosingForOp` (inclusive) whose
  /// induction variables are used, directly or indirectly, in the computation
  /// of `padTensorOp`.
  /// The span of these loops determines the footprint of the packed tensor.
  /// Note: a SmallSetVector<scf::ForOp> would express the intent more
  /// directly; the explicit SetVector instantiation below is used instead.
  SetVector<scf::ForOp, SmallVector<scf::ForOp>, DenseSet<Operation *>>
      packingLoops;

private:
  /// Encodes whether the analysis is valid and hoisting can proceed.
  bool valid;
};

/// Return true if all uses of `padTensorOp` are input tensors of some
/// LinalgOp.
static bool isOnlyUsedAsInputOfLinalgOp(PadTensorOp padTensorOp) {
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use)) {
      LLVM_DEBUG(DBGS() << "Found a use of " << *(padTensorOp)
                        << "\nthat is not an input tensor of a LinalgOp, "
                        << "cannot hoist\n"
                        << *(use.getOwner()) << "\n");
      return false;
    }
  }
  return true;
}

/// Return at most `nLevels` of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loop ops such as scf.parallel or linalg.tiled_loop are not modeled
/// at the moment; neither are control-flow and other containing ops with
/// regions.
static void
getAtMostNEnclosingLoops(PadTensorOp padTensorOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  AsmState state(padTensorOp->getParentOfType<mlir::FuncOp>());
  (void)state;
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padTensorOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(
        DBGS() << "loops: ";
        outermostEnclosingForOp.getInductionVar().printAsOperand(dbgs(), state);
        dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

HoistingAnalysis::HoistingAnalysis(PadTensorOp padTensorOp, int nLevels)
    : padTensorOp(padTensorOp), nLevels(nLevels), valid(false) {
  AsmState state(padTensorOp->getParentOfType<mlir::FuncOp>());
  (void)state;

  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  if (!isOnlyUsedAsInputOfLinalgOp(padTensorOp))
    return;

  // Get at most nLevels of immediately enclosing loops.
  SmallVector<scf::ForOp> reverseEnclosingLoops;
  getAtMostNEnclosingLoops(padTensorOp, nLevels, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "No immediately enclosing loop -> skip\n");
    return;
  }

  outermostEnclosingForOp = reverseEnclosingLoops.back();

  // Get all the ops in the backward slice starting from `padTensorOp` that
  // are dominated by the outermost enclosing loop.
  // Bail on any op with a region that is not either an scf::ForOp or a
  // LinalgOp.
  bool analysisFailure = false;
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(
      padTensorOp.getOperation(), &backwardSlice, [&](Operation *op) {
        if (!domInfo.dominates(outermostEnclosingForOp, op))
          return false;
        if (op != padTensorOp && op->getNumRegions() > 0 &&
            !isa<scf::ForOp, LinalgOp>(op)) {
          analysisFailure = true;
          LLVM_DEBUG(DBGS()
                     << "Unsupported op with region: " << *op << " -> skip\n");
          return false;
        }
        return true;
      });

  if (analysisFailure || backwardSlice.empty())
    return;

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no
  // use in the backwardSlice.
  // The filtered-out loops are the dimensions of reuse that we can exploit to
  // reduce the amount of copies and memory.
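  // For example (an illustrative sketch, names are hypothetical), in
  //   scf.for %i { scf.for %j { ... pad(f(%i)) ... } }
  // the IV %j has no user in the backward slice of the pad, so the %j loop is
  // a reuse dimension and is not recorded as a packing loop.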
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops)) {
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // The analysis is valid and hoisting can occur.
  valid = true;
}

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
}

/// For each loop in `loops`, determine the ops involved in the construction of
/// its upper bound, up to the `outerLimit` loop, and fold them as new
/// inequalities in the constraint set.
/// This is achieved by computing the backward slice of the loop's upper bound
/// and iteratively folding each op in reverse topological order to guarantee
/// use-def ordering.
/// As operations are folded in, their results are projected out of the
/// constraint set.
/// The following operations are supported:
///   - scf::ForOp are simply skipped.
///   - AffineApplyOp are composed to replace the result by an equality.
///   - AffineMinOp are composed by adding each entry as an upper bound.
/// If any other operation is met, return failure.
// TODO: extend on a per-need basis.
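// For instance (an illustrative sketch, values and maps are hypothetical),
// composing
//   %ub = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%i)[%N]
// adds each result of the min map (4 and %N - %i) as an upper bound on the
// dimension id associated with %ub.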
static LogicalResult
foldUpperBoundsIntoConstraintsSet(FlatAffineValueConstraints &constraints,
                                  scf::ForOp outerLimit,
                                  ArrayRef<scf::ForOp> loops) {
  SetVector<Value> toProjectOut;
  for (scf::ForOp loop : loops) {
    auto ub = loop.upperBound();
    if (isDefinedOutsideOrConstant(outerLimit, ub))
      continue;

    // Compute a backward slice up to, but not including, `outerLimit`.
    SetVector<Operation *> backwardSlice;
    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
      return outerLimit->isProperAncestor(op);
    });
    backwardSlice.insert(ub.getDefiningOp());

    // Iterate over all ops in the slice and compose them in the constraints.
    for (Operation *op : llvm::reverse(backwardSlice)) {
      if (!isa<scf::ForOp, AffineApplyOp, AffineMinOp>(op))
        return failure();
      if (isa<scf::ForOp>(op))
        continue;
      // Ensure each result and operand is associated with a dimension id in
      // the constraint set; fail if a value is already bound to a
      // non-dimension (symbol or local) id.
      auto ensureIdFailed = [&](Value v) {
        if (constraints.containsId(v)) {
          unsigned pos;
          constraints.findId(v, &pos);
          return pos >= constraints.getNumDimIds();
        }
        constraints.appendDimId(v);
        return false;
      };

      // Ensure all ids exist and add results for later projection.
      if (llvm::any_of(op->getResults(), ensureIdFailed) ||
          llvm::any_of(op->getOperands(), ensureIdFailed))
        return failure();

      // All supported ops have 1 result.
      // TODO: extend when needed.
      toProjectOut.insert(op->getResult(0));

      // Compose supported ops.
      if (auto affineApplyOp = dyn_cast<AffineApplyOp>(op)) {
        AffineValueMap avm(affineApplyOp.getAffineMap(),
                           affineApplyOp.getOperands(),
                           affineApplyOp.getResult());
        if (failed(constraints.composeMap(&avm)))
          return failure();
        continue;
      }
      auto affineMinOp = cast<AffineMinOp>(op);
      unsigned pos;
      bool foundMinOp = constraints.findId(affineMinOp.getResult(), &pos);
      (void)foundMinOp;
      assert(foundMinOp);
      AffineMap alignedMap = constraints.computeAlignedMap(
          affineMinOp.getAffineMap(), affineMinOp.getOperands());
      if (failed(
              constraints.addBound(FlatAffineConstraints::UB, pos, alignedMap)))
        return failure();
    }
  }
  for (Value v : toProjectOut)
    constraints.projectOut(v);
  return success();
}

// Footprint of the packedTensor, computed from the packingLoops and
// `backwardSlice`.
FailureOr<SmallVector<Value>>
HoistingAnalysis::getPackedTensorSizes(ImplicitLocOpBuilder &b) {
  // Create the base affine constraints for the packingLoops.
  auto constraints = FlatAffineValueConstraints::getHyperrectangular(
      llvm::to_vector<8>(llvm::map_range(
          packingLoops, [](scf::ForOp op) { return op.getInductionVar(); })),
      llvm::to_vector<8>(llvm::map_range(
          packingLoops, [](scf::ForOp op) { return op.lowerBound(); })),
      llvm::to_vector<8>(llvm::map_range(
          packingLoops, [](scf::ForOp op) { return op.upperBound(); })));

  // Iteratively try to fold the upper bounds into the constraints set.
  if (failed(foldUpperBoundsIntoConstraintsSet(
          constraints, outermostEnclosingForOp, packingLoops.getArrayRef())))
    return failure();

  int nPackedLoops = packingLoops.size();
  SmallVector<AffineMap> lbs(nPackedLoops), ubs(nPackedLoops);
  // Compute the bounds of the first `nPackedLoops` positions, assuming the
  // other positions are fixed.
  constraints.getSliceBounds(/*pos=*/0, /*num=*/nPackedLoops,
                             outermostEnclosingForOp->getContext(), &lbs, &ubs);

  SmallVector<Value> allValues;
  constraints.getAllValues(&allValues);
  SmallVector<Value> allNonLoopValues(allValues.begin() + nPackedLoops,
                                      allValues.end());

  // For each packing loop, compute the extent as (ub - lb).ceilDiv(step).
  // The insertion point is just before the outermost enclosing loop that we
  // hoist above.
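  // For example (hypothetical bounds), a packing loop
  //   scf.for %i = %c0 to %ub step %c4
  // contributes a packed dimension of extent ceildiv(%ub - 0, 4).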
  assert(nPackedLoops == static_cast<int64_t>(lbs.size()) &&
         "expected matching lb sizes");
  assert(nPackedLoops == static_cast<int64_t>(ubs.size()) &&
         "expected matching ub sizes");
  SmallVector<Value> dynamicTensorSizes;
  for (auto it : llvm::zip(packingLoops, lbs, ubs)) {
    scf::ForOp loop = std::get<0>(it);
    AffineMap lbMap = std::get<1>(it);
    AffineMap ubMap = std::get<2>(it);
    SmallVector<Value> lbOperands(allNonLoopValues);
    canonicalizeMapAndOperands(&lbMap, &lbOperands);
    Value lbVal = b.createOrFold<AffineMaxOp>(lbMap, lbOperands);

    SmallVector<Value> ubOperands(allNonLoopValues);
    canonicalizeMapAndOperands(&ubMap, &ubOperands);
    Value ubVal = b.createOrFold<AffineMinOp>(ubMap, ubOperands);

    AffineExpr lb, ub, step;
    bindDims(b.getContext(), lb, ub);
    bindSymbols(b.getContext(), step);
    Value res = b.createOrFold<AffineApplyOp>(
        (ub - lb).ceilDiv(step),
        ValueRange{lbVal, ubVal, cast<scf::ForOp>(loop).step()});

    dynamicTensorSizes.push_back(res);
  }
  return dynamicTensorSizes;
}

/// Return true if `v` is a value that is only transitively defined by ops of
/// a type in `OpTypeList`.
template <typename... OpTypeList>
static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) {
  // Compute a backward slice up to, but not including, `outerLimit`.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(v, &backwardSlice, [&](Operation *op) {
    return outerLimit->isProperAncestor(op);
  });
  // Traverse the backward slice and ensure we can perform the computation to
  // hoist.
  for (Operation *op : backwardSlice) {
    if (isa<OpTypeList...>(op))
      continue;
    LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n");
    return false;
  }
  return true;
}

/// Return the current iteration number in the loop, i.e.
/// (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop contained in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
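/// For example (hypothetical bounds), for `scf.for %i = %c3 to %ub step %c4`
/// the returned value is (%i - 3).ceilDiv(4): 0 when %i == 3, 1 when %i == 7,
/// and so on.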
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.lowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.step()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.lowerBound(),
        stepVal = forOp.step();
  auto loc = forOp->getLoc();
  return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
                                       ValueRange{ivVal, lbVal, stepVal});
}

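// Roughly, in pseudo-MLIR (an illustrative sketch only: op spellings, shapes,
// bounds and value names are hypothetical), hoisting a pad by 2 loops rewrites
//
//   scf.for (%i, %j, %k)
//     %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
//     %0 = linalg.pad_tensor %st0 low[0, 0] high[...] : ... to tensor<4x8xf32>
//     compute(%0)
//
// into IR resembling
//
//   scf.for (%i) {
//     %packed_init = linalg.init_tensor [%nk, 4, 8] : tensor<?x4x8xf32>
//     %packed = scf.for (%k) iter_args(%p = %packed_init) {
//       %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
//       %0 = linalg.pad_tensor %st0 low[0, 0] high[...] : ... to
//              tensor<4x8xf32>
//       %1 = tensor.insert_slice %0 into %p[%kk, 0, 0][1, 4, 8][1, 1, 1]
//       scf.yield %1 : tensor<?x4x8xf32>
//     } -> tensor<?x4x8xf32>
//     scf.for (%j, %k) {
//       %st0 = tensor.extract_slice %packed[%kk, 0, 0][1, 4, 8][1, 1, 1]
//                : tensor<?x4x8xf32> to tensor<4x8xf32>
//       compute(%st0)
//     }
//   }
//
// where %nk is the extent computed by getPackedTensorSizes for the %k loop and
// %kk is the loop-independent iteration number of %k, i.e.
// (%k - lb).ceilDiv(step). Loop %j is a pure reuse dimension: its IV does not
// feed the pad, so the padding is computed only once per %k instead of once
// per (%j, %k).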
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  int nLoops) {
  LLVM_DEBUG(DBGS() << "Try to hoist " << *(padTensorOp) << " by " << nLoops
                    << " loops\n");
  HoistingAnalysis analysis(padTensorOp, nLoops);
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
    return failure();
  }

  scf::ForOp outer = analysis.outermostEnclosingForOp;
  ImplicitLocOpBuilder b(outer->getLoc(), outer);

  auto maybeDynamicTensorSizes = analysis.getPackedTensorSizes(b);
  if (failed(maybeDynamicTensorSizes))
    return failure();
  SmallVector<Value> dynamicTensorSizes = *maybeDynamicTensorSizes;

  // The actual number of packed loops may be smaller than the requested
  // `nLoops`.
  int nPackedLoops = analysis.packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding.
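  // For example (illustrative), hoisting a static tensor<4x8xf32> pad over two
  // packing loops produces a packed tensor of type tensor<?x?x4x8xf32>.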
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create an InsertSliceOp at the top of the stack.
  //   3. Iteratively pop and yield the result of the InsertSliceOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nPackedLoops);
  leadingPackedTensorIndexings.reserve(nPackedLoops);
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  analysis.backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : analysis.backwardSlice) {
    // Specifically skip the extract_slice(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
      if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
        continue;
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!analysis.packingLoops.contains(forOp))
      continue;

    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(b, outer, clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create an InsertSliceOp at the top of the stack.
  // offsets = [leadingPackedTensorIndexings, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nPackedLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nPackedLoops + paddedRank,
                                    b.getIndexAttr(1));

  Value inserted =
      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
                                      packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready; replace the original padding op by a
  // 1x..x1 slice [loopIterationCounts, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
      llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [loopIterationCounts, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
                                       packedTensor, offsets, sizes, strides)
          ->getResult(0));

  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}