1 //===- HoistPadding.cpp - Hoisting transformation for PadTensorOp ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting padding operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Affine/Utils.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/SCF/Utils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Dialect/Vector/VectorUtils.h"
24 #include "mlir/IR/AsmState.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Dominance.h"
27 #include "mlir/Transforms/LoopUtils.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Debug.h"
30 
31 using llvm::dbgs;
32 
33 #define DEBUG_TYPE "hoist-padding"
34 
35 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
36 
37 using namespace mlir;
38 using namespace mlir::linalg;
39 
40 /// Analysis class to support PadTensorOp hoisting across multiple enclosing
41 /// loops. The failure conditions are:
42 ///   1. Pad op has a use that is not an input of a LinalgOp.
43 ///   2. There is no immediately enclosing scf::ForOp.
44 ///   3. The backward slice from the pad op to the scf::ForOp to hoist above
45 ///      contains an unknown op with a region.
46 ///   4. The backward slice from the pad op to the scf::ForOp to hoist above is
47 ///      empty.
///   5. The source tensor of the pad op is not defined by an extract slice
///      op.
49 ///   6. The source tensor of the extract slice op is not defined outside of
50 ///      the outermost enclosing scf::ForOp.
51 ///   7. There is no enclosing scf::ForOp that indexes the padded data.
52 /// Other cases succeed and will trigger hoisting of the pad op.
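///
/// For illustration only (names and shapes hypothetical), hoisting by one
/// loop conceptually rewrites:
/// ```
/// scf.for %i
///   %slice = tensor.extract_slice %source [%i] [%sz] [1]
///   %padded = linalg.pad_tensor %slice ...
///   ... // %padded used as an input of a LinalgOp
/// ```
/// into a packing loop that pre-pads all slices into a packed tensor, plus
/// the original loop reading back 1x..x1 slices of that packed tensor.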
53 struct HoistingAnalysis {
54   HoistingAnalysis(PadTensorOp padTensorOp, int numLoops);
55 
56   bool isValid() { return valid; }
57 
58   /// Footprint of the packedTensor, computed from the packingLoops.
59   SmallVector<Value> getPackedTensorSizes(ImplicitLocOpBuilder &b);
60 
  /// The outermost loop, determined by `numLoops`, above which `padTensorOp`
  /// will be hoisted.
63   scf::ForOp outermostEnclosingForOp;
64 
65   /// Backward slice rooted at `padTensorOp` and nested under
66   /// `outermostEnclosingForOp`.
67   SetVector<Operation *> backwardSlice;
68 
  /// The scf::ForOps, immediately enclosing `padTensorOp`, such that:
  ///  1. they are nested under `outermostEnclosingForOp` (inclusive), and
  ///  2. their induction variables are used, directly or indirectly, in the
  ///     computation of `padTensorOp`.
73   /// The span of these loops determines the footprint of the packed tensor.
74   SmallVector<scf::ForOp> packingLoops;
75 
76 private:
77   /// Returns the loops in `backwardSlice` used to index the padded data. The
78   /// method starts from `padTensorOp` and `sliceOp`, follows the use-def
79   /// chains of their index operands, and stores any enclosing loop whose
80   /// induction variable is part of the walked index computation.
81   ///
82   /// Example:
83   /// ```
84   /// %source = linalg.fill(%cst, %arg0)
85   /// scf.for %i
86   ///   scf.for %j
87   ///     scf.for %k // not used to index %source!
88   ///       %ubi = affine.min #map(%i)
89   ///       %ubj = affine.min #map(%j)
90   ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
91   ///       %padded_slice = linalg.pad_tensor %slice
92   /// ```
93   /// getIndexingLoops(%padded_slice, %slice) returns [scf.for %i, scf.for %j]
94   SmallVector<scf::ForOp> getIndexingLoops(PadTensorOp padTensorOp,
95                                            tensor::ExtractSliceOp sliceOp);
96 
97   /// Encodes whether the analysis is valid and hoisting can proceed.
98   bool valid;
99 };
100 
/// Return true if every use of `padTensorOp` is an input tensor of some
/// LinalgOp.
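/// For example (illustrative), a use such as
/// `linalg.matmul ins(%padded_slice, %b : ...) outs(%c : ...)` is accepted,
/// whereas a use by, say, `tensor.insert_slice` is not.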
103 static bool isOnlyUsedAsInputOfLinalgOp(PadTensorOp padTensorOp) {
104   for (OpOperand &use : padTensorOp.result().getUses()) {
105     auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
106     if (!linalgUser || !linalgUser.isInputTensor(&use)) {
107       LLVM_DEBUG(DBGS() << "Found a use of " << *(padTensorOp)
108                         << "\nthat is not an input tensor of a LinalgOp, "
109                         << "cannot hoist\n"
110                         << *(use.getOwner()) << "\n");
111       return false;
112     }
113   }
114   return true;
115 }
116 
/// Return at most `nLevels` of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loop ops such as scf.parallel or linalg.tiled_loop are not modeled
/// at the moment; neither are control-flow or other region-holding ops.
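/// For example (hypothetical nesting), if `padTensorOp` is nested as
/// scf.for -> scf.parallel -> scf.for -> pad, only the innermost scf.for is
/// collected: the outward walk stops at the scf.parallel.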
121 static void
122 getAtMostNEnclosingLoops(PadTensorOp padTensorOp, int nLevels,
123                          SmallVector<scf::ForOp> &reverseEnclosingLoops) {
124   AsmState state(padTensorOp->getParentOfType<mlir::FuncOp>());
125   (void)state;
126   scf::ForOp outermostEnclosingForOp = nullptr;
127   Operation *nextEnclosingOp = padTensorOp->getParentOp();
128   while (nLevels-- > 0 &&
129          (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
130     LLVM_DEBUG(
131         DBGS() << "loops: ";
132         outermostEnclosingForOp.getInductionVar().printAsOperand(dbgs(), state);
133         dbgs() << "\n");
134     reverseEnclosingLoops.push_back(outermostEnclosingForOp);
135     nextEnclosingOp = outermostEnclosingForOp->getParentOp();
136   }
137 }
138 
139 HoistingAnalysis::HoistingAnalysis(PadTensorOp padTensorOp, int numLoops) {
140   valid = false;
141 
  // Bail on any use that isn't an input of a LinalgOp.
  // Hoisting of in-place updates happens after vectorization.
144   if (!isOnlyUsedAsInputOfLinalgOp(padTensorOp))
145     return;
146 
147   // Get at most nLevels of immediately enclosing loops.
148   SmallVector<scf::ForOp> reverseEnclosingLoops;
149   getAtMostNEnclosingLoops(padTensorOp, numLoops, reverseEnclosingLoops);
150   if (reverseEnclosingLoops.empty()) {
151     LLVM_DEBUG(DBGS() << "No immediately enclosing loop -> skip\n");
152     return;
153   }
154 
155   outermostEnclosingForOp = reverseEnclosingLoops.back();
156 
  // Get all the ops in the backward slice starting from `padTensorOp` that
  // are dominated by the outermost enclosing loop.
  // Bail on any op with a region that is neither an scf::ForOp nor a LinalgOp.
160   bool analysisFailure = false;
161   DominanceInfo domInfo(outermostEnclosingForOp);
162   getBackwardSlice(
163       padTensorOp.getOperation(), &backwardSlice, [&](Operation *op) {
164         if (!domInfo.dominates(outermostEnclosingForOp, op))
165           return false;
166         if (op != padTensorOp && op->getNumRegions() > 0 &&
167             !isa<scf::ForOp, LinalgOp>(op)) {
168           analysisFailure = true;
169           LLVM_DEBUG(DBGS()
170                      << "Unsupported op with region: " << *op << " -> skip\n");
171           return false;
172         }
173         return true;
174       });
175 
176   if (analysisFailure || backwardSlice.empty())
177     return;
178 
179   // Get the `sliceOp` that defines the source tensor of `padTensorOp` and
180   // check its source is defined outside of the outermost loop. This check
181   // ensures the padded data is available for packing before entering the
182   // outermost enclosing loop.
183   //
184   // Example:
185   // ```
186   // %source = linalg.fill(%cst, %arg0)
187   // // %source is available for packing here!
188   // scf.for %i
189   //   scf.for %j
190   //     scf.for %k
191   //       %slice = tensor.extract_slice %source [%i, %j]
192   //       %padded_slice = linalg.pad_tensor %slice
193   // ```
194   auto sliceOp = padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
195   if (!sliceOp) {
196     LLVM_DEBUG(DBGS() << "Cannot find the extract slice op -> skip\n");
197     return;
198   }
199   if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.source())) {
200     LLVM_DEBUG(DBGS() << "Source not defined outside of loops -> skip\n");
201     return;
202   }
203 
  // Search `backwardSlice` for the loops used to index the padded data.
205   SmallVector<scf::ForOp> indexingLoops =
206       getIndexingLoops(padTensorOp, sliceOp);
207 
  // Add to the packing loops only the loops that are part of `indexingLoops`.
  // All other loops are not used to index the padded data and consequently
  // access the same data in every loop iteration. Adding them to the packing
  // loops would increase the cache footprint of the packed data by storing
  // the same data multiple times.
213   for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
214     if (!indexingLoops.empty() && indexingLoops.back() == forOp)
215       packingLoops.push_back(indexingLoops.pop_back_val());
  assert(indexingLoops.empty() &&
         "expected all indexing loops to be enclosing loops");
218 
219   if (packingLoops.empty()) {
220     LLVM_DEBUG(DBGS() << "Cannot find a packing loop -> skip\n");
221     return;
222   }
223 
224   // The analysis is valid and hoisting can occur.
225   valid = true;
226 }
227 
228 SmallVector<scf::ForOp>
229 HoistingAnalysis::getIndexingLoops(PadTensorOp padTensorOp,
230                                    tensor::ExtractSliceOp sliceOp) {
231   // Set of all values used for index computation.
232   SetVector<Value> indexEdges;
233 
234   // Add all index operands of `operation` to `indexEdges`. An index operand is
235   // an operand of type index.
236   auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
237     for (Value operand : operation->getOperands())
238       if (operand.getType().isIndex())
239         indexEdges.insert(operand);
240   };
241 
  // Starting from `padTensorOp` and `sliceOp`, walk the use-def edges of
  // index type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` if one of its results is an index edge found so far, and
  // store all loops that are part of the index computation in `indexingLoops`.
246   //
247   // Example:
248   // ```
249   // %source = linalg.fill(%cst, %arg0)
250   // scf.for %i
251   //   scf.for %j
252   //     scf.for %k // not used to index %source!
253   //       %ubi = affine.min #map(%i)
254   //       %ubj = affine.min #map(%j)
255   //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
256   //       %padded_slice = linalg.pad_tensor %slice
257   // ```
258   // After iterating `backwardSlice` we obtain:
259   // indexEdges = [%i, %j, %ubi, %ubj]
260   // indexingLoops = [scf.for %i, scf.for %j]
261   SmallVector<scf::ForOp> indexingLoops;
262   for (Operation *op : llvm::reverse(backwardSlice)) {
263     // Add the index operands of `padTensorOp` and `sliceOp` to start the
264     // exploration of the index computation.
265     if (op == padTensorOp || op == sliceOp) {
266       addIndexOperandsToIndexEdges(op);
267       continue;
268     }
    // Add the index operands of the loop if its induction variable is used
    // for index computation. Additionally, insert the loop into
    // `indexingLoops`.
272     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
273       if (indexEdges.contains(forOp.getInductionVar())) {
274         addIndexOperandsToIndexEdges(op);
275         indexingLoops.push_back(forOp);
276         continue;
277       }
278     }
279     // Add the index operands of all other operations if at least one result is
280     // used for index computation.
281     if (llvm::any_of(op->getResults(),
282                      [&](Value result) { return indexEdges.contains(result); }))
283       addIndexOperandsToIndexEdges(op);
284   }
285   return indexingLoops;
286 }
287 
288 SmallVector<Value>
289 HoistingAnalysis::getPackedTensorSizes(ImplicitLocOpBuilder &b) {
290   SmallVector<Value> dynamicTensorSizes;
291 
292   // Upper bound the packing loop lengths to size the packed tensor. Taking
293   // upper bounds can make the sizes of the packed tensor independent of the
294   // enclosing loops. This independence is a prerequisite for reusing the same
295   // buffer for all enclosing loop iterations and hoisting its allocation out of
296   // the enclosing loops.
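  // For example (names illustrative), if a packing loop has upper bound
  // %ub = affine.min affine_map<(d0) -> (16, -d0 + 100)>(%i), then 16 is a
  // valid upper bound independent of %i, and using it makes the packed
  // tensor size loop-independent.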
297   for (auto forOp : packingLoops) {
298     // Compute an upper bound `ubVal` for the upper bound of `forOp`.
299     AffineMap boundMap;
300     SmallVector<Value> boundOperands;
301     getUpperBoundForIndex(forOp.upperBound(), boundMap, boundOperands);
302     Value ubVal = b.createOrFold<AffineMinOp>(boundMap, boundOperands);
303     // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
304     // store the result to `dynamicTensorSizes`.
305     // TODO: instead of using the lower bound of `forOp` directly, implement a
306     // lower bound computation similar to the upper bound computation.
307     AffineExpr lb, ub, step;
308     bindDims(b.getContext(), lb, ub);
309     bindSymbols(b.getContext(), step);
310     Value res = b.createOrFold<AffineApplyOp>(
311         (ub - lb).ceilDiv(step),
        ValueRange{forOp.lowerBound(), ubVal, forOp.step()});
313     dynamicTensorSizes.push_back(res);
314   }
315 
316   return dynamicTensorSizes;
317 }
318 
319 static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
320   return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
321 }
322 
/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop contained in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
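/// For example (hypothetical loop), for `scf.for %i = %c0 to %n step %c4`
/// this builds the equivalent of
///   affine.apply affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>
///     (%i, %c0)[%c4]
/// assuming %c0 and %c4 are constants or are defined outside of `outer`.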
327 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
328                                      scf::ForOp forOp) {
329   MLIRContext *ctx = forOp->getContext();
330   AffineExpr iv, lb, step;
331   bindDims(ctx, iv, lb);
332   bindSymbols(ctx, step);
333   if (!isDefinedOutsideOrConstant(outer, forOp.lowerBound()) ||
334       !isDefinedOutsideOrConstant(outer, forOp.step()))
335     return Value();
336   Value ivVal = forOp.getInductionVar(), lbVal = forOp.lowerBound(),
337         stepVal = forOp.step();
338   auto loc = forOp->getLoc();
339   return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
340                                        ValueRange{ivVal, lbVal, stepVal});
341 }
342 
343 FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(PadTensorOp opToHoist,
344                                                      int numLoops,
345                                                      PadTensorOp &hoistedOp) {
346   LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
347                     << " loops\n");
348   HoistingAnalysis analysis(opToHoist, numLoops);
349   if (!analysis.isValid()) {
350     LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
351     return failure();
352   }
353 
354   scf::ForOp outer = analysis.outermostEnclosingForOp;
355   ImplicitLocOpBuilder b(outer->getLoc(), outer);
356 
357   SmallVector<Value> dynamicTensorSizes = analysis.getPackedTensorSizes(b);
358 
  // Update the actual number of loops, which may be smaller than `numLoops`.
360   int nPackedLoops = analysis.packingLoops.size();
361 
362   Location loc = opToHoist->getLoc();
363   RankedTensorType paddedTensorType = opToHoist.getResultType();
364   int paddedRank = paddedTensorType.getRank();
365 
366   // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
367   // padding.
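  // For instance (shapes illustrative), with 2 packing loops and a padded
  // type of tensor<4x8xf32>, the packed type is tensor<?x?x4x8xf32>.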
368   SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamicSize);
369   // TODO: go grab dims when necessary, for now PadTensorOp returns a static
370   // tensor.
371   llvm::append_range(packedShape, paddedTensorType.getShape());
372   auto packedTensorType =
373       RankedTensorType::get(packedShape, paddedTensorType.getElementType());
374   Value packedTensor = b.create<linalg::InitTensorOp>(
375       loc, dynamicTensorSizes, packedTensorType.getShape(),
376       packedTensorType.getElementType());
377 
378   // Clone the operations involved in the backward slice, iteratively stepping
379   // into the loops that we encounter.
380   // The implementation proceeds in a stack-like fashion:
381   //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
382   //      deeper in the stack.
  //   2. Create an InsertSliceOp at the top of the stack.
  //   3. Iteratively pop the stack and yield the result of the InsertSliceOp
  //      across the cloned loops.
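  // Conceptually, for a single packing loop this creates (illustrative):
  //   %packed = scf.for %i = ... iter_args(%p = %init) {
  //     ... // cloned backward slice, including the pad op
  //     %ins = tensor.insert_slice %padded into %p [%it, 0, 0] [1, 4, 8] ...
  //     scf.yield %ins
  //   }
  // where %it is the loop-independent iteration count built below.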
386   SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
387   clonedLoopIvs.reserve(nPackedLoops);
388   leadingPackedTensorIndexings.reserve(nPackedLoops);
389   BlockAndValueMapping bvm;
390   // Insert `opToHoist` into the backwardSlice so we clone it too.
391   analysis.backwardSlice.insert(opToHoist);
392   // Stack step 1. iteratively clone loops and push `packedTensor`.
393   for (Operation *op : analysis.backwardSlice) {
    // Specifically skip the extract_slice(packedTensor) case: this is the
    // piece we seek to replace.
396     if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
397       if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
398         continue;
399     auto effects = dyn_cast<MemoryEffectOpInterface>(op);
400     bool hasNoEffects = !effects || effects.hasNoEffect();
401     if (hasNoEffects &&
402         (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
403       b.clone(*op, bvm);
404       continue;
405     }
406     // TODO: support more cases as they appear.
407     auto forOp = dyn_cast<scf::ForOp>(op);
408     assert(forOp && "Expected scf::ForOp when hoisting pad ops");
409     // Unused loop, just skip it.
410     if (!llvm::is_contained(analysis.packingLoops, forOp))
411       continue;
412 
413     auto clonedForOp =
414         b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
415                              bvm.lookupOrDefault(forOp.upperBound()),
416                              bvm.lookupOrDefault(forOp.step()), packedTensor);
417     // Map the induction var, region args and results to the `clonedForOp`.
418     bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
419     bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
420     bvm.map(forOp.getResults(), clonedForOp.getResults());
421     assert(clonedForOp->getNumRegions() == 1);
422     clonedLoopIvs.push_back(clonedForOp.getInductionVar());
423 
424     b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
425     Value loopIndependentIterationCount =
426         buildLoopIterationCount(b, outer, clonedForOp);
427     // Assert the loop-independent iteration count can be computed.
428     if (!loopIndependentIterationCount)
429       llvm_unreachable("loop independence prerequisite not met");
430     leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
431     packedTensor = clonedForOp.getRegionIterArgs().front();
432   }
433 
434   // Stack step 2. create InsertSliceOp at the top of the stack.
435   // offsets = [clonedLoopIvs, 0 .. 0].
436   SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
437                                     leadingPackedTensorIndexings.end());
438   offsets.append(paddedRank, b.getIndexAttr(0));
439   // sizes = [1 .. 1, paddedShape].
440   SmallVector<OpFoldResult> sizes(nPackedLoops, b.getIndexAttr(1));
441   for (int64_t sz : paddedTensorType.getShape()) {
442     // TODO: go grab dims when necessary, for now PadTensorOp returns a static
443     // tensor.
444     assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
445     sizes.push_back(b.getIndexAttr(sz));
446   }
447   // strides = [1 .. 1].
448   SmallVector<OpFoldResult> strides(nPackedLoops + paddedRank,
449                                     b.getIndexAttr(1));
450 
451   Value inserted =
452       b.create<tensor::InsertSliceOp>(loc, bvm.lookup(opToHoist.result()),
453                                       packedTensor, offsets, sizes, strides);
454 
455   // Stack step 3. iteratively pop the stack and propagate the yield.
456   Value valueToYield = inserted;
457   for (Value iv : llvm::reverse(clonedLoopIvs)) {
458     auto forOp = scf::getForInductionVarOwner(iv);
459     b.setInsertionPointToEnd(&forOp.getRegion().front());
460     b.create<scf::YieldOp>(loc, valueToYield);
461     valueToYield = forOp.getResult(0);
462   }
463 
464   // Now the packed tensor is ready, replace the original padding op by a
465   // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
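  // For example (illustrative), with 2 packing loops and a padded type of
  // tensor<4x8xf32>, the replacement is a rank-reducing slice:
  //   %res = tensor.extract_slice %packed [%it_i, %it_j, 0, 0]
  //            [1, 1, 4, 8] [1, 1, 1, 1]
  //            : tensor<?x?x4x8xf32> to tensor<4x8xf32>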
466   b.setInsertionPoint(opToHoist);
467   SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
468       llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
469         return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
470       }));
471   // Assert all loop iteration counts can be computed.
472   if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
473     llvm_unreachable("loop independence prerequisite not met");
474   // offsets = [originalLoopIvs, 0 .. 0].
475   offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
476   offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
479   packedTensor =
480       scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
481   Value newResult = b.create<tensor::ExtractSliceOp>(
482       loc, opToHoist.getResultType(), packedTensor, offsets, sizes, strides);
483 
484   // Make the newly cloned `opToHoist` available to the caller.
485   hoistedOp = cast<PadTensorOp>(bvm.lookup(opToHoist.result()).getDefiningOp());
486   return newResult;
487 }
488