1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/AffineStructures.h"
16 #include "mlir/Analysis/SliceAnalysis.h"
17 #include "mlir/Dialect/Affine/Utils.h"
18 #include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/SCF/SCF.h"
22 #include "mlir/Dialect/SCF/Utils.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/VectorOps.h"
26 #include "mlir/Dialect/Vector/VectorUtils.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/Dominance.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/LoopUtils.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Debug.h"
33 
34 using llvm::dbgs;
35 
36 #define DEBUG_TYPE "linalg-hoisting"
37 
38 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
39 
40 using namespace mlir;
41 using namespace mlir::linalg;
42 
namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  // The loop-invariant transfer_write whose result feeds the loop yield; a
  // null op handle denotes "no hoistable write found".
  vector::TransferWriteOp transferWriteOp;
  // Optional insert_slice that re-inserts the written tensor into the loop's
  // iter_arg; null when the write operates on the full tensor.
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  // The transfer_read matching the hoistable write (same indices and vector
  // type, see findMatchingTransferRead).
  vector::TransferReadOp transferReadOp;
  // Optional extract_slice feeding the read; when present it mirrors the
  // write's insert_slice (same offsets/sizes/strides).
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace
57 
58 /// Return true if op1 and op2 are the same constant or the same SSA value.
59 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
60   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
61     Attribute attr = ofr.dyn_cast<Attribute>();
62     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
63     if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
64       attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
65     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
66       return intAttr.getValue().getSExtValue();
67     return llvm::None;
68   };
69   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
70   if (cst1 && cst2 && *cst1 == *cst2)
71     return true;
72   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
73   return v1 && v2 && v1 == v2;
74 }
75 
76 /// Return true is all offsets, sizes and strides are equal.
77 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
78                                        tensor::InsertSliceOp si) {
79   if (s.static_offsets().size() != si.static_offsets().size())
80     return false;
81   if (s.static_sizes().size() != si.static_sizes().size())
82     return false;
83   if (s.static_strides().size() != si.static_strides().size())
84     return false;
85   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
86     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
87       return false;
88   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
89     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
90       return false;
91   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
92     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
93       return false;
94   return true;
95 }
96 
97 /// Look for a HoistableRead, in the given tensor uses, accessing the same
98 /// offset as the HoistableWrite.
99 static HoistableRead findMatchingTransferRead(HoistableWrite write,
100                                               Value srcTensor) {
101   assert(write.transferWriteOp &&
102          "expected hoistable write to have a .transfer_write");
103 
104   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
105                     << *write.transferWriteOp.getOperation() << "\n");
106   if (write.insertSliceOp)
107     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
108                       << *write.insertSliceOp.getOperation() << "\n");
109 
110   for (Operation *user : srcTensor.getUsers()) {
111     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
112                       << "\n");
113 
114     // If HoistableWrite involves a InsertSliceOp, we need to find a
115     // matching ExtractSliceOp.
116     tensor::ExtractSliceOp sliceOp;
117     Operation *maybeTransferReadUser = user;
118     if (write.insertSliceOp) {
119       sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
120       if (!sliceOp || sliceOp.getResult().getType() !=
121                           write.insertSliceOp.source().getType())
122         continue;
123 
124       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
125                         << *sliceOp << " vs " << *write.insertSliceOp << "\n");
126       if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
127         continue;
128 
129       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
130       // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
131       //   1. the transfer_write we want to hoist.
132       //   2. a matching transfer_read.
133       // Anything else, we skip.
134       bool skip = false;
135       Operation *otherUser = nullptr;
136       for (Operation *u : sliceOp->getUsers()) {
137         if (u == write.transferWriteOp)
138           continue;
139         if (otherUser) {
140           skip = true;
141           break;
142         }
143         otherUser = u;
144       }
145       if (skip || !otherUser)
146         continue;
147       maybeTransferReadUser = otherUser;
148     }
149 
150     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
151                       << "\n");
152     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
153     if (read && read.indices() == write.transferWriteOp.indices() &&
154         read.getVectorType() == write.transferWriteOp.getVectorType())
155       return HoistableRead{read, sliceOp};
156   }
157   return HoistableRead();
158 }
159 
/// Check if the chunk of data inserted by the HoistableWrite are read by any
/// other op than the HoistableRead candidate.
/// Conservative: returns true ("accessed by unknown op") whenever aliasing
/// cannot be ruled out; returns false only when every other use is provably
/// disjoint from the written chunk.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use-ranges: starts at the loop's tensor argument and grows as
  // transitive uses (through transfer_write results, nested scf.for args and
  // scf.yield results) are discovered.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through a extract_slice / insert_slice.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the corresponding region iter_arg: skip the
        // control operands and account for the induction variable.
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        // Continue the traversal with the corresponding result of the
        // enclosing op that yields this operand.
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // A remaining use is acceptable only if it is a transfer_read that
      // provably accesses indices disjoint from the hoistable write.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}
217 
/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value is produced directly by a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value is an insert_slice whose source comes from a
  // transfer_write.
  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    // Only the offset/size/stride operands need checking here; source and
    // dest were handled above.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  // Anything else is not hoistable.
  return HoistableWrite();
}
259 
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// Moves the (optional extract_slice +) transfer_read before `forOp` and the
/// transfer_write (+ optional insert_slice) after it, then rewires the loop
/// via cloneWithNewYields so the vector — rather than the tensor — is carried
/// across iterations.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: the hoisted read must now source from the
  // loop's init operand rather than from the (loop-local) iter_arg.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: pass the tensor through unchanged; the actual update
  // now happens after the loop via the hoisted write.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields that carry the read/written
  // vector across iterations.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}
329 
// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// transfer_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// removed by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Iterate to a fixed point: each successful hoist may expose new candidates.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      // Inspect each yielded iter_arg for a hoistable write.
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        // The write must exist and be consumed only by the yield.
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
392 
393 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
394   bool changed = true;
395   while (changed) {
396     changed = false;
397 
398     func.walk([&](vector::TransferReadOp transferRead) {
399       if (!transferRead.getShapedType().isa<MemRefType>())
400         return WalkResult::advance();
401 
402       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
403                         << *transferRead.getOperation() << "\n");
404       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
405       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
406                         << "\n");
407       if (!loop)
408         return WalkResult::advance();
409 
410       if (failed(moveLoopInvariantCode(
411               cast<LoopLikeOpInterface>(loop.getOperation()))))
412         llvm_unreachable(
413             "Unexpected failure to move invariant code out of loop");
414 
415       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
416                         << "\n");
417 
418       SetVector<Operation *> forwardSlice;
419       getForwardSlice(transferRead.getOperation(), &forwardSlice);
420 
421       // Look for the last TransferWriteOp in the forwardSlice of
422       // `transferRead` that operates on the same memref.
423       vector::TransferWriteOp transferWrite;
424       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
425         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
426         if (!candidateWrite || candidateWrite.source() != transferRead.source())
427           continue;
428         transferWrite = candidateWrite;
429       }
430 
431       // All operands of the TransferRead must be defined outside of the loop.
432       for (auto operand : transferRead.getOperands())
433         if (!loop.isDefinedOutsideOfLoop(operand))
434           return WalkResult::advance();
435 
436       // Only hoist transfer_read / transfer_write pairs for now.
437       if (!transferWrite)
438         return WalkResult::advance();
439 
440       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
441                         << "\n");
442 
443       // Approximate aliasing by checking that:
444       //   1. indices are the same,
445       //   2. no other operations in the loop access the same memref except
446       //      for transfer_read/transfer_write accessing statically disjoint
447       //      slices.
448       if (transferRead.indices() != transferWrite.indices() &&
449           transferRead.getVectorType() == transferWrite.getVectorType())
450         return WalkResult::advance();
451 
452       // TODO: may want to memoize this information for performance but it
453       // likely gets invalidated often.
454       DominanceInfo dom(loop);
455       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
456         return WalkResult::advance();
457       for (auto &use : transferRead.source().getUses()) {
458         if (!dom.properlyDominates(loop, use.getOwner()))
459           continue;
460         if (use.getOwner() == transferRead.getOperation() ||
461             use.getOwner() == transferWrite.getOperation())
462           continue;
463         if (auto transferWriteUse =
464                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
465           if (!isDisjointTransferSet(
466                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
467                   cast<VectorTransferOpInterface>(
468                       transferWriteUse.getOperation())))
469             return WalkResult::advance();
470         } else if (auto transferReadUse =
471                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
472           if (!isDisjointTransferSet(
473                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
474                   cast<VectorTransferOpInterface>(
475                       transferReadUse.getOperation())))
476             return WalkResult::advance();
477         } else {
478           // Unknown use, we cannot prove that it doesn't alias with the
479           // transferRead/transferWrite operations.
480           return WalkResult::advance();
481         }
482       }
483 
484       // Hoist read before.
485       if (failed(loop.moveOutOfLoop({transferRead})))
486         llvm_unreachable(
487             "Unexpected failure to move transfer read out of loop");
488 
489       // Hoist write after.
490       transferWrite->moveAfter(loop);
491 
492       // Rewrite `loop` with new yields by cloning and erase the original loop.
493       OpBuilder b(transferRead);
494       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
495                                          transferWrite.vector());
496 
497       // Transfer write has been hoisted, need to update the written value to
498       // the value yielded by the newForOp.
499       transferWrite.vector().replaceAllUsesWith(
500           newForOp.getResults().take_back()[0]);
501 
502       changed = true;
503       loop.erase();
504       // Need to interrupt and restart because erasing the loop messes up the
505       // walk.
506       return WalkResult::interrupt();
507     });
508   }
509 }
510 
511 /// Return success if `v` is a value that is only transitively defined by ops of
512 /// type in `OpTypeList`.
513 template <typename... OpTypeList>
514 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) {
515   // Compute a backward slice up to, but not including, `outerLimit`.
516   SetVector<Operation *> backwardSlice;
517   getBackwardSlice(v, &backwardSlice, [&](Operation *op) {
518     return outerLimit->isProperAncestor(op);
519   });
520   // Traverse the backward slice and ensure we can perform the computation to
521   // hoist.
522   for (Operation *op : backwardSlice) {
523     if (isa<OpTypeList...>(op))
524       continue;
525     LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n");
526     return false;
527   }
528   return true;
529 }
530 
531 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
532   return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
533 }
534 
535 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
536 /// The returned Value is guaranteed not to depend on any loop comprised in
537 /// [`outer`, `forOp`].
538 /// Return null if such a loop-independent quantity cannot be computed.
539 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
540                                      scf::ForOp forOp) {
541   MLIRContext *ctx = forOp->getContext();
542   AffineExpr iv, lb, step;
543   bindDims(ctx, iv, lb);
544   bindSymbols(ctx, step);
545   if (!isDefinedOutsideOrConstant(outer, forOp.lowerBound()) ||
546       !isDefinedOutsideOrConstant(outer, forOp.step()))
547     return Value();
548   Value ivVal = forOp.getInductionVar(), lbVal = forOp.lowerBound(),
549         stepVal = forOp.step();
550   auto loc = forOp->getLoc();
551   return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
552                                        ValueRange{ivVal, lbVal, stepVal});
553 }
554 
555 /// Given a set of loops, assumed to be scf::ForOp, create a constraint set
556 /// containing the inequalities `iv - lb >= 0` and `-iv + ub - 1 >= 0` for each
557 /// loop.
558 static ConstraintsSet initLoopIvsAndBounds(ArrayRef<Operation *> loops) {
559   ConstraintsSet constraints;
560   for (Operation *op : loops)
561     constraints.addDimId(constraints.getNumDimIds(),
562                          cast<scf::ForOp>(op).getInductionVar());
563   for (Operation *op : loops)
564     constraints.addDimId(constraints.getNumDimIds(),
565                          cast<scf::ForOp>(op).lowerBound());
566   for (Operation *op : loops)
567     constraints.addDimId(constraints.getNumDimIds(),
568                          cast<scf::ForOp>(op).upperBound());
569   unsigned numLoops = loops.size();
570   for (unsigned ivIdx = 0, e = numLoops; ivIdx < e; ++ivIdx) {
571     // iv - lb >= 0
572     SmallVector<int64_t, 8> ineqLb(constraints.getNumCols(), 0);
573     ineqLb[ivIdx] = 1;
574     ineqLb[ivIdx + numLoops] = -1;
575     // -iv + ub >= 0
576     SmallVector<int64_t, 8> ineqUb(constraints.getNumCols(), 0);
577     ineqUb[ivIdx] = -1;
578     ineqUb[ivIdx + 2 * numLoops] = 1;
579     ineqUb[constraints.getNumCols() - 1] = -1;
580     constraints.addInequality(ineqLb);
581     constraints.addInequality(ineqUb);
582   }
583   return constraints;
584 }
585 
/// For each loop in `loops`, determine the ops involved in the construction of
/// its upper bound---up to the outerLimit loop--- and fold them as new
/// inequalities in the constraint set.
/// This is achieved by computing the backwardSlice of the loop's upper bound
/// and iteratively folding each op in reverse topological order to guarantee
/// use-def ordering.
/// As operations are folded in, their result is projected out of the
/// constraints set.
/// The following operations are supported:
///   - scf::ForOp are simply skipped.
///   - AffineApplyOp are composed to replace the result by an equality.
///   - AffineMinOp are composed by adding each entry as an upper bound.
/// If any other operation is met, return failure.
// TODO: extend on a per-need basis.
static LogicalResult
foldUpperBoundsIntoConstraintsSet(ConstraintsSet &constraints,
                                  scf::ForOp outerLimit,
                                  ArrayRef<Operation *> loops) {
  SetVector<Value> toProjectOut;
  for (Operation *loop : loops) {
    auto ub = cast<scf::ForOp>(loop).upperBound();
    // Loop-independent or constant bounds require no folding.
    if (isDefinedOutsideOrConstant(outerLimit, ub))
      continue;

    // Compute a backward slice up to, but not including, `outerLimit`.
    SetVector<Operation *> backwardSlice;
    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
      return outerLimit->isProperAncestor(op);
    });
    // Explicitly add the defining op of `ub`, which the backward slice may
    // not include.
    backwardSlice.insert(ub.getDefiningOp());

    // Iterate over all ops in the slice and compose them in the constraints.
    for (Operation *op : llvm::reverse(backwardSlice)) {
      if (!isa<scf::ForOp, AffineApplyOp, AffineMinOp>(op))
        return failure();
      if (isa<scf::ForOp>(op))
        continue;
      // Helper returning true when a (dim) id for `v` cannot be ensured in
      // the constraints.
      auto ensureIdFailed = [&](Value v) {
        return failed(constraints.ensureIdOfType(v, /*asDim=*/true));
      };

      // Ensure all ids exist and add results for later projection.
      if (llvm::any_of(op->getResults(), ensureIdFailed) ||
          llvm::any_of(op->getOperands(), ensureIdFailed))
        return failure();

      // All supported ops have 1 result.
      // TODO: extend when needed.
      toProjectOut.insert(op->getResult(0));

      // Compose supported ops.
      if (auto affineApplyOp = dyn_cast<AffineApplyOp>(op)) {
        if (failed(constraints.composeAffineApply(affineApplyOp.getResult(),
                                                  affineApplyOp.getAffineMap(),
                                                  affineApplyOp.getOperands())))
          return failure();
        continue;
      }
      auto affineMinOp = cast<AffineMinOp>(op);
      if (failed(constraints.composeMin(affineMinOp.getResult(),
                                        affineMinOp.getAffineMap(),
                                        affineMinOp.operands())))
        return failure();
    }
  }
  // Intermediate results are implementation details of the bounds: project
  // them out of the final system.
  for (Value v : toProjectOut)
    constraints.projectOut(v);
  return success();
}
656 
657 /// Compute dynamic tensor sizes, independent of any value defined inside
658 /// `outer` and such that every n-D iteration of the packingLoops has its own
659 /// space (so that each packed buffer has a storage location). This is achieved
660 /// by computing the extent for each of the packing loops.
661 static LogicalResult computeBounds(scf::ForOp outer,
662                                    ArrayRef<Operation *> packingLoops,
663                                    SmallVector<AffineMap> &lbs,
664                                    SmallVector<AffineMap> &ubs) {
665   // Packing loop IVs are introduced as the first positions.
666   ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops);
667   if (failed(
668           foldUpperBoundsIntoConstraintsSet(constraints, outer, packingLoops)))
669     return failure();
670   // Compute the bounds of the first positions, assuming the others are fixed.
671   constraints.getSliceBounds(/*pos=*/0, /*num=*/packingLoops.size(),
672                              outer->getContext(), &lbs, &ubs);
673   return success();
674 }
675 
676 /// Ensure prerequisites that guarantee pad op hoisting can occur.
677 /// Return failure in the cases when we cannot perform hoisting; i.e. if either:
678 ///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
679 ///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
680 ///   3. There exists an op with a region that is dominated by
681 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
682 ///    LinalgOp.
683 ///   4. There exists an op with side effects that is dominated by
684 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///   5. The lower bound, upper bound or step of any loop involved in the
///   hoisting cannot be computed independently of the enclosing loops (i.e.
///   cannot be expressed from values defined outside of
///   `outermostEnclosingForOp`).
687 ///
688 /// While ensuring prerequisites:
689 ///   1. Fill the `backwardSlice` to contain the topologically sorted ops
690 ///   dominated by `outermostEnclosingForOp`.
691 ///   2. Fill the `packingLoops` to contain only the enclosing loops of
692 ///   `backwardSlice` whose IV is actually used in computing padding. Loops that
693 ///   remain in `backwardSlice` but that are not in `packingLoops` are
694 ///   dimensions of reuse.
695 static LogicalResult
696 hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
697                                    SetVector<Operation *> &backwardSlice,
698                                    SetVector<Operation *> &packingLoops,
699                                    SmallVector<Value> &dynamicTensorSizes) {
700   // Bail on any use that isn't an input of a Linalg op.
701   // Hoisting of inplace updates happens after vectorization.
702   for (OpOperand &use : padTensorOp.result().getUses()) {
703     auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
704     if (!linalgUser || !linalgUser.isInputTensor(&use))
705       return failure();
706   }
707 
708   // Get at most nLevels of enclosing loops.
709   SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
710   Operation *outermostEnclosingForOp = nullptr,
711             *nextEnclosingForOp =
712                 padTensorOp->getParentOfType<LoopLikeOpInterface>();
713   while (nLevels-- > 0 && nextEnclosingForOp) {
714     outermostEnclosingForOp = nextEnclosingForOp;
715     reverseEnclosingLoops.push_back(outermostEnclosingForOp);
716     nextEnclosingForOp =
717         nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
718   }
719   if (!outermostEnclosingForOp)
720     return failure();
721 
722   // Get the backwards slice from `padTensorOp` that is dominated by the
723   // outermost enclosing loop.
724   DominanceInfo domInfo(outermostEnclosingForOp);
725   getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
726                    [&](Operation *op) {
727                      return domInfo.dominates(outermostEnclosingForOp, op);
728                    });
729 
730   // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
731   if (llvm::any_of(backwardSlice, [](Operation *op) {
732         return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
733                !isa<LinalgOp>(op);
734       }))
735     return failure();
736 
737   // Filter out the loops whose induction variable is not used to compute the
738   // padded result. As a first approximation, just look for IVs that have no use
739   // in the backwardSlice.
740   // These are the dimensions of reuse that we can exploit to reduce the amount
741   // of work / memory.
742   // TODO: would this optimization compose better as a canonicalization?
743   for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
744     auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
745     if (!forOp)
746       continue;
747     for (Operation *user : forOp.getInductionVar().getUsers()) {
748       if (backwardSlice.contains(user)) {
749         packingLoops.insert(forOp);
750         break;
751       }
752     }
753   }
754 
755   // Backward slice is a topologically sorted list of ops starting at
756   // `outermostEnclosingForOp`.
757   assert(outermostEnclosingForOp == backwardSlice.front());
758 
759   scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
760 
761   ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops.getArrayRef());
762   if (failed(foldUpperBoundsIntoConstraintsSet(constraints, outer,
763                                                packingLoops.getArrayRef())))
764     return failure();
765 
766   unsigned numLoops = packingLoops.size();
767   SmallVector<AffineMap> lbs(numLoops), ubs(numLoops);
768   if (failed(computeBounds(outer, packingLoops.getArrayRef(), lbs, ubs)))
769     return failure();
770 
771   SmallVector<Value> allValues;
772   constraints.getAllValues(&allValues);
773   SmallVector<Value> allNonLoopValues(allValues.begin() + numLoops,
774                                       allValues.end());
775 
776   // For each packingLoop, create the extent by (ub - lb).ceilDiv(step).
777   // IP just before the outermost loop considered that we hoist above.
778   ImplicitLocOpBuilder b(outer->getLoc(), outer);
779   assert(packingLoops.size() == lbs.size() && "expected matching lb sizes");
780   assert(packingLoops.size() == ubs.size() && "expected matching ub sizes");
781   for (auto it : llvm::zip(packingLoops, lbs, ubs)) {
782     scf::ForOp loop = cast<scf::ForOp>(std::get<0>(it));
783     AffineMap lbMap = std::get<1>(it);
784     AffineMap ubMap = std::get<2>(it);
785     SmallVector<Value> lbOperands(allNonLoopValues);
786     canonicalizeMapAndOperands(&lbMap, &lbOperands);
787     Value lbVal = b.createOrFold<AffineMaxOp>(lbMap, lbOperands);
788 
789     SmallVector<Value> ubOperands(allNonLoopValues);
790     canonicalizeMapAndOperands(&ubMap, &ubOperands);
791     Value ubVal = b.createOrFold<AffineMinOp>(ubMap, ubOperands);
792 
793     AffineExpr lb, ub, step;
794     bindDims(b.getContext(), lb, ub);
795     bindSymbols(b.getContext(), step);
796     Value res = b.createOrFold<AffineApplyOp>(
797         (ub - lb).ceilDiv(step),
798         ValueRange{lbVal, ubVal, cast<scf::ForOp>(loop).step()});
799 
800     dynamicTensorSizes.push_back(res);
801   }
802 
803   return success();
804 }
805 
/// Hoist the padding computation of `padTensorOp` above (up to) `nLoops` of
/// its enclosing scf.for loops: the backward slice of `padTensorOp` is cloned
/// above the outermost enclosing loop into freshly created "packing" loops
/// that insert each padded tile into a larger packed tensor; the original pad
/// is then replaced by a rank-reducing extract_slice from that packed tensor.
/// On success, `padTensorOp` is updated in place to the newly cloned pad op so
/// the caller keeps a handle on the hoisted computation.
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  // Gather the slice, the packing loops and the dynamic sizes of the packed
  // tensor; bail out if any hoisting prerequisite is not met.
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding. One leading dynamic dimension per packing loop.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a InsertSliceOp at the top of the stack.
  //   3. Iteratively pop and yield the result of the InsertSliceOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : backwardSlice) {
    // Specifically sit out in the extract_slice(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
      if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
        continue;
    // Side-effect-free ops without regions (plus the pad itself) can be cloned
    // verbatim; the bvm remaps their operands to the cloned values.
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;

    // Clone the packing loop, threading `packedTensor` through as the single
    // iter_arg so each level yields the updated packed tensor.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Step into the cloned loop body; subsequent clones land inside it.
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount = buildLoopIterationCount(
        b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create InsertSliceOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  // Insert the cloned padded tile into the innermost iter_arg tensor.
  Value inserted =
      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
                                      packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  // Walk innermost-to-outermost, yielding each level's updated tensor.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(
            b, cast<scf::ForOp>(outermostEnclosingForOp),
            cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above)
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
                                       packedTensor, offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the original op before re-seating `padTensorOp`.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}
959