//===- HoistPadding.cpp - Hoisting for tensor::PadOp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. Pad op does not have a constant padding value.
///   3. There is no immediately enclosing scf::ForOp.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with non-index-type operands, a region, or a
///      memory effect.
///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
///      empty.
///   6. The source tensor of the pad op is not defined by an extract slice op.
///   7. The source tensor of the extract slice op is not defined outside of
///      the outermost enclosing scf::ForOp.
///   8. There is no enclosing scf::ForOp that indexes the padded data.
/// Other cases succeed and will trigger hoisting of the pad op.
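///
/// A schematic example of the enabled transformation (simplified IR with
/// illustrative names, not produced verbatim by the pass):
/// ```
/// scf.for %i
///   scf.for %j
///     %slice = tensor.extract_slice %source [%i, %j]
///     %padded_slice = tensor.pad %slice
///     ... linalg op consuming %padded_slice ...
/// ```
/// Hoisting by one loop packs all padded slices of the `%j` iterations into a
/// packed tensor above `scf.for %j` and replaces `%padded_slice` by an
/// extract_slice of that packed tensor.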
struct HoistingAnalysis {
  HoistingAnalysis(tensor::PadOp padOp, int numLoops);

  bool isValid() { return valid; }

  /// Footprint of the packedTensor, computed from the packingLoops.
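  /// For example (schematic), for a single packing loop
  /// `scf.for %i = %lb to %ub step %s`, this returns the value
  /// `(%ub' - %lb) ceildiv %s`, where `%ub'` is a loop-independent upper
  /// bound computed for `%ub`.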
  SmallVector<Value> getPackedTensorSizes(ImplicitLocOpBuilder &b);

  /// The outermost loop, determined by `numLoops`, above which `padOp` will
  /// be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps immediately enclosing `padOp` that:
  ///  1. are nested under `outermostEnclosingForOp` (inclusive), and
  ///  2. have an induction variable that is used, directly or indirectly, in
  ///     the computation of `padOp`.
  /// The span of these loops determines the footprint of the packed tensor.
  SmallVector<scf::ForOp> packingLoops;

private:
  /// Drop any non-index dependencies of `padOp` and `sliceOp` from
  /// `backwardSlice`. The method follows the use-def chains of the index
  /// operands consumed by `padOp` and `sliceOp` and drops the operations
  /// not part of this index computation. Afterwards, the filtered
  /// `backwardSlice` contains only the loops whose induction variable is used,
  /// directly or indirectly, to index the padded tensor. The method returns
  /// failure if the filtered backward slice contains an unexpected operation.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// scf.for %i
  ///   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  ///   scf.for %j (%arg2 = %unrelated)
  ///     scf.for %k                             // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  /// dropNonIndexDependencies(%padded_slice, %slice)
  /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
  LogicalResult dropNonIndexDependencies(tensor::PadOp padOp,
                                         tensor::ExtractSliceOp sliceOp);

  /// Encodes whether the analysis is valid and hoisting can proceed.
  bool valid;
};

/// Return true if all uses of `padOp` are an input tensor of some
/// LinalgOp.
static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) {
  for (OpOperand &use : padOp.getResult().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use)) {
      LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp)
                        << "\nthat is not an input tensor of a LinalgOp, "
                        << "cannot hoist\n"
                        << *(use.getOwner()) << "\n");
      return false;
    }
  }
  return true;
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
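/// For example (schematic), for a pad op nested as
/// `scf.for %a { scf.parallel { scf.for %b { tensor.pad } } }` and
/// nLevels = 3, only `scf.for %b` is collected: the walk stops at the
/// enclosing scf.parallel.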
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  AsmState state(padOp->getParentOfType<func::FuncOp>());
  (void)state;
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(
        DBGS() << "loops: ";
        outermostEnclosingForOp.getInductionVar().printAsOperand(dbgs(), state);
        dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

/// Return the transposed `rankedTensorType` if `transposeVector` is
/// non-empty. Fail if `transposeVector` is not a permutation matching the
/// tensor rank.
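/// For example, `transposeVector = [1, 0]` transposes `tensor<4x8xf32>` into
/// `tensor<8x4xf32>`.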
static FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
                      ArrayRef<int64_t> transposeVector) {
  if (transposeVector.empty())
    return rankedTensorType;
  if (!isPermutation(transposeVector) ||
      transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
    return failure();

  SmallVector<int64_t> transposedShape(rankedTensorType.getShape().begin(),
                                       rankedTensorType.getShape().end());
  applyPermutationToVector(transposedShape, transposeVector);

  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType transposedTensorType =
      RTTBuilder(rankedTensorType).setShape(transposedShape);
  return transposedTensorType;
}

HoistingAnalysis::HoistingAnalysis(tensor::PadOp padOp, int numLoops) {
  valid = false;

  // Bail on any use that isn't an input of a LinalgOp.
  // Hoisting of inplace updates happens after vectorization.
  if (!isOnlyUsedAsInputOfLinalgOp(padOp))
    return;

  // Get at most `numLoops` of immediately enclosing loops.
  SmallVector<scf::ForOp> reverseEnclosingLoops;
  getAtMostNEnclosingLoops(padOp, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "No immediately enclosing loop -> skip\n");
    return;
  }

  outermostEnclosingForOp = reverseEnclosingLoops.back();

  // Get the `sliceOp` that defines the source tensor of `padOp` and
  // check its source is defined outside of the outermost loop. This check
  // ensures the padded data is available for packing before entering the
  // outermost enclosing loop.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // // %source is available for packing here!
  // scf.for %i
  //   scf.for %j
  //     scf.for %k
  //       %slice = tensor.extract_slice %source [%i, %j]
  //       %padded_slice = tensor.pad %slice
  // ```
  auto sliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "Cannot find the extract slice op -> skip\n");
    return;
  }
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "Source not defined outside of loops -> skip\n");
    return;
  }

  // Check that the region of `padOp` depends only on a constant. Adding
  // hoisting support for arbitrary padding regions would require cloning all
  // dependencies captured by the padding region.
  Value paddingValue = padOp.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> skip\n");
    return;
  }

  // Get all the ops in the backward slice starting from `padOp` that are
  // dominated by the outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padOp.getOperation(), &backwardSlice, [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op);
  });
  if (backwardSlice.empty())
    return;
  // Add `padOp` itself to the backward slice.
  backwardSlice.insert(padOp.getOperation());

  // Remove all ops in the backward slice that are not used to index the padded
  // tensor. In particular, keep `padOp`, `sliceOp`, and the loop and
  // affine operations used for the index computation.
  if (failed(dropNonIndexDependencies(padOp, sliceOp)))
    return;

  // Add only the loops that are part of the filtered `backwardSlice` to the
  // packing loops. All other loops are not used to index the padded data and
  // consequently access the same data in every loop iteration. Adding them to
  // the packing loops would increase the cache footprint of the packed data
  // by storing the same data multiple times.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);
  if (packingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "Cannot find a packing loop -> skip\n");
    return;
  }

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult
HoistingAnalysis::dropNonIndexDependencies(tensor::PadOp padOp,
                                           tensor::ExtractSliceOp sliceOp) {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand is
  // an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any operation result is contained in `indexEdges`.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Starting from `padOp` and `sliceOp` walk the use-def edges of index
  // type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` and remove all operations from `backwardSlice` that are not
  // part of the index computation.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // scf.for %i
  //   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  //   scf.for %j (%arg2 = %unrelated)
  //     scf.for %k                             // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  //       %padded_slice = tensor.pad %slice
  // ```
  // After iterating `backwardSlice` we obtain:
  // indexEdges = [%i, %j, %ubi, %ubj]
  // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `padOp` and `sliceOp` to start the
    // exploration of the index computation.
    if (op == padOp || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of the loop if its induction variable is
    // used for index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result is
    // used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check that the operands of the remaining operations all have index
      // type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> skip\n");
        return failure();
      }
      // Check that the remaining operations do not have regions or memory
      // effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation.
    // Exceptions are constant operations that may be used by `padOp`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistingAnalysis::getPackedTensorSizes(ImplicitLocOpBuilder &b) {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    AffineMap boundMap;
    SmallVector<Value> boundOperands;
    getUpperBoundForIndex(forOp.getUpperBound(), boundMap, boundOperands);
    Value ubVal = b.createOrFold<AffineMinOp>(boundMap, boundOperands);
    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(b.getContext(), lb, ub);
    bindSymbols(b.getContext(), step);
    Value res = b.createOrFold<AffineApplyOp>(
        (ub - lb).ceilDiv(step), ValueRange{forOp.getLowerBound(), ubVal,
                                            cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
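/// For example (schematic), for `scf.for %iv = %lb to %ub step %s` this
/// materializes
/// `affine.apply affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>(%iv, %lb)[%s]`,
/// which createOrFold may fold to a constant.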
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
                                       ValueRange{ivVal, lbVal, stepVal});
}

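/// A hypothetical caller sketch (the declaration and full documentation live
/// in Transforms.h; `rewriter` and `padOp` are assumed to exist in the
/// caller):
/// ```
/// tensor::PadOp hoistedOp;
/// SmallVector<GenericOp> transposeOps;
/// FailureOr<Value> newResult = hoistPaddingOnTensors(
///     padOp, /*numLoops=*/2, /*transposeVector=*/{}, hoistedOp, transposeOps);
/// if (succeeded(newResult))
///   rewriter.replaceOp(padOp, *newResult);
/// ```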
FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    tensor::PadOp opToHoist, int numLoops, ArrayRef<int64_t> transposeVector,
    tensor::PadOp &hoistedOp, SmallVectorImpl<GenericOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "Try to hoist " << *(opToHoist) << " by " << numLoops
                    << " loops\n");
  HoistingAnalysis analysis(opToHoist, numLoops);
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "Analysis failed -> Skip\n");
    return failure();
  }

  scf::ForOp outer = analysis.outermostEnclosingForOp;
  ImplicitLocOpBuilder b(outer->getLoc(), outer);

  SmallVector<Value> dynamicTensorSizes = analysis.getPackedTensorSizes(b);

  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Compute the type of the transposed padded tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType))
    return failure();

  // Create the packed tensor<?x?x..?xtransposedShape> into which we amortize
  // padding.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now tensor::PadOp returns a static
  // tensor.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto packedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a GenericOp if `transposeVector` is non-empty.
  //   3. Create an InsertSliceOp at the top of the stack.
  //   4. Iteratively pop and yield the result of the InsertSliceOp across
  //      the cloned loops.
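  //
  // Schematically, for a single packing loop, the nest produced by steps 1,
  // 3, and 4 looks like (illustrative names, not produced verbatim):
  // ```
  // %packed = scf.for %kk = %lb to %ub step %s iter_args(%p = %init) {
  //   %slice = tensor.extract_slice %source [...]
  //   %padded = tensor.pad %slice ...
  //   %it = affine.apply ((%kk - %lb) ceildiv %s)
  //   %ins = tensor.insert_slice %padded into
  //          %p[%it, 0, 0] [1, sz0, sz1] [1, 1, 1]
  //   scf.yield %ins
  // }
  // ```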
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nPackedLoops);
  leadingPackedTensorIndexings.reserve(nPackedLoops);
  BlockAndValueMapping bvm;
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(packedTensor) case: this is
    // the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
      if (bvm.lookupOrDefault(sliceOp.getSource()) == packedTensor)
        continue;
    // Clone all operations except loops.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      b.clone(*op, bvm);
      continue;
    }
    // Create a packing loop that takes `packedTensor` as iteration argument.
    auto clonedForOp = b.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(b, outer, clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // offsets = [leadingPackedTensorIndexings, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  SmallVector<OpFoldResult> sizes(nPackedLoops, b.getIndexAttr(1));
  for (int64_t sz : transposedTensorType->getShape()) {
    // TODO: go grab dims when necessary, for now tensor::PadOp returns a
    // static tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nPackedLoops + paddedRank,
                                    b.getIndexAttr(1));

  // Stack step 2. create GenericOp if `transposeVector` is non-empty.
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = b.create<tensor::ExtractSliceOp>(
        loc, *transposedTensorType, packedTensor, offsets, sizes, strides);
    transposeOps.push_back(
        makeTransposeOp(b, loc, paddedTensor, outputTensor, transposeVector));
    paddedTensor = transposeOps.back()->getResult(0);
  }

  // Stack step 3. create InsertSliceOp at the top of the stack.
  Value inserted = b.create<tensor::InsertSliceOp>(
      loc, paddedTensor, packedTensor, offsets, sizes, strides);

  // Stack step 4. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(opToHoist);
  SmallVector<Value> loopIterationCounts = llvm::to_vector<4>(
      llvm::map_range(analysis.packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(b, outer, cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  Value newResult = b.create<tensor::ExtractSliceOp>(
      loc, *transposedTensorType, packedTensor, offsets, sizes, strides);

  // Transpose the packed tensor back to the original storage order.
  if (!transposeVector.empty()) {
    Value initTensor =
        b.create<InitTensorOp>(loc, ValueRange{}, paddedTensorType.getShape(),
                               paddedTensorType.getElementType());
    transposeOps.push_back(
        makeTransposeOp(b, loc, newResult, initTensor, transposeVector));
    newResult = transposeOps.back()->getResult(0);
  }

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp =
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp());
  return newResult;
}