1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
19 #include "mlir/Dialect/SCF/SCF.h"
20 #include "mlir/Dialect/SCF/Utils.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Dialect/Vector/VectorOps.h"
24 #include "mlir/Dialect/Vector/VectorUtils.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/Dominance.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "mlir/Transforms/LoopUtils.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Debug.h"
31 
32 using llvm::dbgs;
33 
34 #define DEBUG_TYPE "linalg-hoisting"
35 
36 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
37 
38 using namespace mlir;
39 using namespace mlir::linalg;
40 
namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  /// The vector.transfer_write producing the value yielded by the loop.
  vector::TransferWriteOp transferWriteOp;
  /// Optional tensor.insert_slice the write feeds into; null when the
  /// transfer_write result is yielded directly.
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  /// The vector.transfer_read reading from the loop's iter_arg tensor.
  vector::TransferReadOp transferReadOp;
  /// Optional tensor.extract_slice the read goes through; null when the read
  /// accesses the tensor directly.
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace
55 
56 /// Return true if op1 and op2 are the same constant or the same SSA value.
57 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
58   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
59     Attribute attr = ofr.dyn_cast<Attribute>();
60     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
61     if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
62       attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
63     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
64       return intAttr.getValue().getSExtValue();
65     return llvm::None;
66   };
67   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
68   if (cst1 && cst2 && *cst1 == *cst2)
69     return true;
70   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
71   return v1 && v2 && v1 == v2;
72 }
73 
74 /// Return true is all offsets, sizes and strides are equal.
75 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
76                                        tensor::InsertSliceOp si) {
77   if (s.static_offsets().size() != si.static_offsets().size())
78     return false;
79   if (s.static_sizes().size() != si.static_sizes().size())
80     return false;
81   if (s.static_strides().size() != si.static_strides().size())
82     return false;
83   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
84     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
85       return false;
86   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
87     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
88       return false;
89   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
90     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
91       return false;
92   return true;
93 }
94 
95 /// Look for a HoistableRead, in the given tensor uses, accessing the same
96 /// offset as the HoistableWrite.
97 static HoistableRead findMatchingTransferRead(HoistableWrite write,
98                                               Value srcTensor) {
99   assert(write.transferWriteOp &&
100          "expected hoistable write to have a .transfer_write");
101 
102   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
103                     << *write.transferWriteOp.getOperation() << "\n");
104   if (write.insertSliceOp)
105     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
106                       << *write.insertSliceOp.getOperation() << "\n");
107 
108   for (Operation *user : srcTensor.getUsers()) {
109     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
110                       << "\n");
111 
112     // If HoistableWrite involves a InsertSliceOp, we need to find a
113     // matching ExtractSliceOp.
114     tensor::ExtractSliceOp sliceOp;
115     Operation *maybeTransferReadUser = user;
116     if (write.insertSliceOp) {
117       sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
118       if (!sliceOp || sliceOp.getResult().getType() !=
119                           write.insertSliceOp.source().getType())
120         continue;
121 
122       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
123                         << *sliceOp << " vs " << *write.insertSliceOp << "\n");
124       if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
125         continue;
126 
127       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
128       // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
129       //   1. the transfer_write we want to hoist.
130       //   2. a matching transfer_read.
131       // Anything else, we skip.
132       bool skip = false;
133       Operation *otherUser = nullptr;
134       for (Operation *u : sliceOp->getUsers()) {
135         if (u == write.transferWriteOp)
136           continue;
137         if (otherUser) {
138           skip = true;
139           break;
140         }
141         otherUser = u;
142       }
143       if (skip || !otherUser)
144         continue;
145       maybeTransferReadUser = otherUser;
146     }
147 
148     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
149                       << "\n");
150     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
151     if (read && read.indices() == write.transferWriteOp.indices() &&
152         read.getVectorType() == write.transferWriteOp.getVectorType())
153       return HoistableRead{read, sliceOp};
154   }
155   return HoistableRead();
156 }
157 
/// Check if the chunk of data inserted by the HoistableWrite are read by any
/// other op than the HoistableRead candidate.
/// Returns true (i.e. "unknown access exists") conservatively whenever a use
/// cannot be proven disjoint from the written chunk.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use ranges: seeded with the iter_arg's uses and grown with
  // transitive uses discovered below.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through a extract_slice / insert_slice.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the matching region iter_arg: subtract the
        // control operands (lb/ub/step) and skip the induction variable
        // block argument (+1).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // Any remaining use must be a transfer_read on a slice statically
      // disjoint from the write; otherwise report an unknown access.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}
215 
216 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
217 /// Return the null HoistableWrite() if it is not comprised of a
218 /// vector.transfer_write + optional insert_slice or if any of the indexings
219 /// is `forOp`-dependent.
220 static HoistableWrite
221 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
222                                         OpOperand &yieldOperand) {
223   Value v = yieldOperand.get();
224   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
225     // Indexing must not depend on `forOp`.
226     for (Value operand : write.indices())
227       if (!forOp.isDefinedOutsideOfLoop(operand))
228         return HoistableWrite();
229 
230     return HoistableWrite{write, nullptr};
231   }
232 
233   if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
234     // Inserted slice must come from vector.transfer_write.
235     auto write =
236         insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
237     if (!write)
238       return HoistableWrite();
239 
240     // Tensor inserted into must be a BBArg at position matching yieldOperand's.
241     auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
242     if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
243         bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
244       return HoistableWrite();
245 
246     // Indexing inserted into must not depend on `forOp`.
247     for (Value operand : insertSliceOp->getOperands().drop_front(
248              tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
249       if (!forOp.isDefinedOutsideOfLoop(operand))
250         return HoistableWrite();
251 
252     return HoistableWrite{write, insertSliceOp};
253   }
254 
255   return HoistableWrite();
256 }
257 
/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// The read (and optional extract_slice) is moved before the loop, the write
/// (and optional insert_slice) after it, and the loop is cloned with an extra
/// yielded vector that threads the value between them. The order of the IR
/// mutations below is load-bearing.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: the hoisted read must now read from the loop's
  // init operand instead of the (no longer visible) iter_arg.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the loop now yields the unmodified tensor, the write is
  // applied once after the loop instead.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}
327 
// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// tensor_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can
// be removed by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Iterate to a fixed point: each successful hoist may expose more.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      // Inspect each yielded iter_arg for a hoistable write.
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
390 
391 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
392   bool changed = true;
393   while (changed) {
394     changed = false;
395 
396     func.walk([&](vector::TransferReadOp transferRead) {
397       if (!transferRead.getShapedType().isa<MemRefType>())
398         return WalkResult::advance();
399 
400       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
401                         << *transferRead.getOperation() << "\n");
402       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
403       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
404                         << "\n");
405       if (!loop)
406         return WalkResult::advance();
407 
408       if (failed(moveLoopInvariantCode(
409               cast<LoopLikeOpInterface>(loop.getOperation()))))
410         llvm_unreachable(
411             "Unexpected failure to move invariant code out of loop");
412 
413       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
414                         << "\n");
415 
416       SetVector<Operation *> forwardSlice;
417       getForwardSlice(transferRead.getOperation(), &forwardSlice);
418 
419       // Look for the last TransferWriteOp in the forwardSlice of
420       // `transferRead` that operates on the same memref.
421       vector::TransferWriteOp transferWrite;
422       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
423         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
424         if (!candidateWrite || candidateWrite.source() != transferRead.source())
425           continue;
426         transferWrite = candidateWrite;
427       }
428 
429       // All operands of the TransferRead must be defined outside of the loop.
430       for (auto operand : transferRead.getOperands())
431         if (!loop.isDefinedOutsideOfLoop(operand))
432           return WalkResult::advance();
433 
434       // Only hoist transfer_read / transfer_write pairs for now.
435       if (!transferWrite)
436         return WalkResult::advance();
437 
438       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
439                         << "\n");
440 
441       // Approximate aliasing by checking that:
442       //   1. indices are the same,
443       //   2. no other operations in the loop access the same memref except
444       //      for transfer_read/transfer_write accessing statically disjoint
445       //      slices.
446       if (transferRead.indices() != transferWrite.indices() &&
447           transferRead.getVectorType() == transferWrite.getVectorType())
448         return WalkResult::advance();
449 
450       // TODO: may want to memoize this information for performance but it
451       // likely gets invalidated often.
452       DominanceInfo dom(loop);
453       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
454         return WalkResult::advance();
455       for (auto &use : transferRead.source().getUses()) {
456         if (!dom.properlyDominates(loop, use.getOwner()))
457           continue;
458         if (use.getOwner() == transferRead.getOperation() ||
459             use.getOwner() == transferWrite.getOperation())
460           continue;
461         if (auto transferWriteUse =
462                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
463           if (!isDisjointTransferSet(
464                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
465                   cast<VectorTransferOpInterface>(
466                       transferWriteUse.getOperation())))
467             return WalkResult::advance();
468         } else if (auto transferReadUse =
469                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
470           if (!isDisjointTransferSet(
471                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
472                   cast<VectorTransferOpInterface>(
473                       transferReadUse.getOperation())))
474             return WalkResult::advance();
475         } else {
476           // Unknown use, we cannot prove that it doesn't alias with the
477           // transferRead/transferWrite operations.
478           return WalkResult::advance();
479         }
480       }
481 
482       // Hoist read before.
483       if (failed(loop.moveOutOfLoop({transferRead})))
484         llvm_unreachable(
485             "Unexpected failure to move transfer read out of loop");
486 
487       // Hoist write after.
488       transferWrite->moveAfter(loop);
489 
490       // Rewrite `loop` with new yields by cloning and erase the original loop.
491       OpBuilder b(transferRead);
492       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
493                                          transferWrite.vector());
494 
495       // Transfer write has been hoisted, need to update the written value to
496       // the value yielded by the newForOp.
497       transferWrite.vector().replaceAllUsesWith(
498           newForOp.getResults().take_back()[0]);
499 
500       changed = true;
501       loop.erase();
502       // Need to interrupt and restart because erasing the loop messes up the
503       // walk.
504       return WalkResult::interrupt();
505     });
506   }
507 }
508 
509 /// Return success if `v` is a value that is only transitively defined by ops of
510 /// type in `OpTypeList`.
511 template <typename... OpTypeList>
512 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) {
513   // Compute a backward slice up to, but not including, `outerLimit`.
514   SetVector<Operation *> backwardSlice;
515   getBackwardSlice(v, &backwardSlice, [&](Operation *op) {
516     return outerLimit->isProperAncestor(op);
517   });
518   // Traverse the backward slice and ensure we can perform the computation to
519   // hoist.
520   for (Operation *op : backwardSlice) {
521     if (isa<OpTypeList...>(op))
522       continue;
523     LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n");
524     return false;
525   }
526   return true;
527 }
528 
529 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
530   return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
531 }
532 
533 /// Compute the tightest lower bound with quantities that are all defined
534 /// outside of `outer`.
535 /// Return null if such a bound cannot be computed.
536 Value computeLoopIndependentLowerBound(OpBuilder &b, scf::ForOp outer,
537                                        Value v) {
538   if (isDefinedOutsideOrConstant(outer, v))
539     return v;
540   return Value();
541 }
542 
/// Compute the tightest upper bound with quantities that are all defined
/// outside of `outer`.
/// Expects all ops in the backward slice of `v` up to `outer` to be either
/// scf.for, affine.min or affine.apply.
static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer,
                                              Value v) {
  // Fast path: already loop-independent.
  if (isDefinedOutsideOrConstant(outer, v))
    return v;

  LLVM_DEBUG(DBGS() << "Begin loopIndependentUpperBound for: " << v << "\n");

  bool ok =
      backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp, AffineApplyOp>(
          outer, v);
  assert(ok && "expected to only be defined by scf::ForOp and AffineMinOp");
  (void)ok;

  // Compute a backward slice up to, but not including, `outer`.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(v, &backwardSlice,
                   [&](Operation *op) { return outer->isProperAncestor(op); });
  // `v`'s own defining op is part of the computation to clone.
  backwardSlice.insert(v.getDefiningOp());

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(outer);
  Value res = v;
  // Maps original values to their clones hoisted above `outer`.
  BlockAndValueMapping bvm;
  for (Operation *op : backwardSlice) {
    // Loops themselves are not cloned; their IVs get substituted away below.
    if (isa<scf::ForOp>(op))
      continue;
    if (isa<AffineApplyOp>(op)) {
      b.clone(*op, bvm);
      continue;
    }
    // The remaining ops can only be AffineMinOp (checked by the slice
    // assertion above).
    auto sliceMinOp = cast<AffineMinOp>(op);
    GetMinMaxExprFn getSCFMinMax = [&](Value value,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *op) {
        return outer->isAncestor(op);
      });
    };
    // Perform the substitution of the operands of AffineMinOp.
    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
    SmallVector<Value> resultOperands = mapAndOperands.dims;
    llvm::append_range(resultOperands, mapAndOperands.symbols);
    AffineMap map = mapAndOperands.map;
    canonicalizeMapAndOperands(&map, &resultOperands);
    map = simplifyAffineMap(map);
    // Re-create the min above `outer`, remapping operands through `bvm` so
    // previously cloned computations are reused.
    res = b.create<AffineMinOp>(
        outer->getLoc(), map,
        llvm::to_vector<4>(llvm::map_range(resultOperands, [&](Value operand) {
          return bvm.lookupOrDefault(operand);
        })));
    bvm.map(sliceMinOp, res);
  }
  LLVM_DEBUG(DBGS() << "End loopIndependentUpperBound with: " << res << "\n");
  return res;
}
602 
603 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
604 /// The returned Value is guaranteed not to depend on any loop comprised in
605 /// [`outer`, `forOp`].
606 /// Return null if such a loop-independent quantity cannot be computed.
607 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp outer,
608                                 scf::ForOp forOp) {
609   MLIRContext *ctx = forOp->getContext();
610   AffineExpr lb, ub, step;
611   bindDims(ctx, lb, ub);
612   bindSymbols(ctx, step);
613   Value lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
614         ubVal = computeLoopIndependentUpperBound(b, outer, forOp.upperBound()),
615         stepVal = forOp.step();
616   if (!lbVal || !ubVal || !stepVal)
617     return Value();
618   auto loc = forOp->getLoc();
619   Value res = b.create<AffineApplyOp>(loc, (ub - lb).ceilDiv(step),
620                                       ValueRange{lbVal, ubVal, stepVal});
621   return res;
622 }
623 
624 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
625 /// The returned Value is guaranteed not to depend on any loop comprised in
626 /// [`outer`, `forOp`].
627 /// Return null if such a loop-independent quantity cannot be computed.
628 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
629                                      scf::ForOp forOp) {
630   MLIRContext *ctx = forOp->getContext();
631   AffineExpr iv, lb, step;
632   bindDims(ctx, iv, lb);
633   bindSymbols(ctx, step);
634   Value ivVal = forOp.getInductionVar(),
635         lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
636         stepVal = forOp.step();
637   if (!ivVal || !lbVal || !stepVal)
638     return Value();
639   auto loc = forOp->getLoc();
640   return b.create<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
641                                  ValueRange{ivVal, lbVal, stepVal});
642 }
643 
644 /// Ensure prerequisites that guarantee pad op hoisting can occur.
645 /// Return failure in the cases when we cannot perform hoisting; i.e. if either:
646 ///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
647 ///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
648 ///   3. There exists an op with a region that is dominated by
649 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
650 ///    LinalgOp.
651 ///   4. There exists an op with side effects that is dominated by
652 ///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
653 ///   5. The lower bound, upper bound and step of all the loops involved in the
654 ///   hoisting can be
655 ///
656 /// While ensuring prerequisites:
657 ///   1. Fill the `backwardSlice` to contain the topologically sorted ops
658 ///   dominated by `outermostEnclosingForOp`.
659 ///   2. Fill the `packingLoops` to contain only the enclosing loops of
660 ///   `backwardSlice` whose IV is actually used in computing padding. Loops that
661 ///   remain in `backwardSlice` but that are not in `packingLoops` are
662 ///   dimensions of reuse.
663 static LogicalResult
664 hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
665                                    SetVector<Operation *> &backwardSlice,
666                                    SetVector<Operation *> &packingLoops,
667                                    SmallVector<Value> &dynamicTensorSizes) {
668   // Bail on any use that isn't an input of a Linalg op.
669   // Hoisting of inplace updates happens after vectorization.
670   for (OpOperand &use : padTensorOp.result().getUses()) {
671     auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
672     if (!linalgUser || !linalgUser.isInputTensor(&use))
673       return failure();
674   }
675 
676   // Get at most nLevels of enclosing loops.
677   SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
678   Operation *outermostEnclosingForOp = nullptr,
679             *nextEnclosingForOp =
680                 padTensorOp->getParentOfType<LoopLikeOpInterface>();
681   while (nLevels-- > 0 && nextEnclosingForOp) {
682     outermostEnclosingForOp = nextEnclosingForOp;
683     reverseEnclosingLoops.push_back(outermostEnclosingForOp);
684     nextEnclosingForOp =
685         nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
686   }
687   if (!outermostEnclosingForOp)
688     return failure();
689 
690   // Get the backwards slice from `padTensorOp` that is dominated by the
691   // outermost enclosing loop.
692   DominanceInfo domInfo(outermostEnclosingForOp);
693   getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
694                    [&](Operation *op) {
695                      return domInfo.dominates(outermostEnclosingForOp, op);
696                    });
697 
698   // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
699   if (llvm::any_of(backwardSlice, [](Operation *op) {
700         return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
701                !isa<LinalgOp>(op);
702       }))
703     return failure();
704 
705   // Filter out the loops whose induction variable is not used to compute the
706   // padded result. As a first approximation, just look for IVs that have no use
707   // in the backwardSlice.
708   // These are the dimensions of reuse that we can exploit to reduce the amount
709   // of work / memory.
710   // TODO: would this optimization compose better as a canonicalization?
711   for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
712     auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
713     if (!forOp)
714       continue;
715     for (Operation *user : forOp.getInductionVar().getUsers()) {
716       if (backwardSlice.contains(user)) {
717         packingLoops.insert(forOp);
718         break;
719       }
720     }
721   }
722 
723   // Backward slice is a topologically sorted list of ops starting at
724   // `outermostEnclosingForOp`.
725   assert(outermostEnclosingForOp == backwardSlice.front());
726 
727   scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
728   if (llvm::any_of(packingLoops, [&](Operation *op) {
729         scf::ForOp forOp = cast<scf::ForOp>(op);
730         Value lb = forOp.lowerBound(), ub = forOp.upperBound(),
731               step = forOp.step();
732         return !isDefinedOutsideOrConstant(outer, lb) ||
733                !(isDefinedOutsideOrConstant(outer, ub) ||
734                  backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp,
735                                                 AffineApplyOp>(outer, ub)) ||
736                !isDefinedOutsideOrConstant(outer, step);
737       }))
738     return failure();
739 
740   // IP just before the outermost loop considered that we hoist above.
741   OpBuilder b(outermostEnclosingForOp);
742   dynamicTensorSizes =
743       llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
744         return buildLoopTripCount(b, cast<scf::ForOp>(outermostEnclosingForOp),
745                                   cast<scf::ForOp>(op));
746       }));
747   // Assert all loop trip counts can be computed.
748   if (!llvm::all_of(dynamicTensorSizes, [](Value v) { return v; }))
749     llvm_unreachable("loop independence prerequisite not met");
750   return success();
751 }
752 
/// Hoist the padding of `padTensorOp` above the packing loops identified by
/// `hoistPaddingOnTensorsPrerequisites`, amortizing the padding into a larger
/// packed tensor allocated just above the outermost enclosing loop. On
/// success, `padTensorOp` is updated in place to point at the cloned pad op
/// inside the packing loops so the caller keeps a valid handle.
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller than the requested
  // `nLoops`: only loops whose IV feeds the padded result are packed.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding. One leading dynamic dimension per packing loop (its trip count),
  // followed by the static padded shape.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a InsertSliceOp at the top of the stack.
  //   3. Iteratively pop and yield the result of the InsertSliceOp across
  //     the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  // Maps original values to their clones so cloned ops reference cloned
  // operands.
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : backwardSlice) {
    // Specifically sit out in the extract_slice(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
      if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
        continue;
    // Side-effect-free ops without regions (plus PadTensorOp itself, whose
    // region is self-contained) can be cloned verbatim.
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;

    // Clone the loop, threading `packedTensor` through as an iter_arg so the
    // growing packed tensor is carried across iterations.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Step into the cloned loop body; subsequent clones nest inside it.
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount = buildLoopIterationCount(
        b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    // Push: inside this loop, the packed tensor is the loop-carried value.
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create InsertSliceOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  // Insert the cloned pad's result into the packed tensor at the position
  // indexed by the per-loop iteration counts.
  Value inserted =
      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
                                      packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(
            b, cast<scf::ForOp>(outermostEnclosingForOp),
            cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  // NOTE(review): `clonedLoopIvs.front()` assumes at least one packing loop
  // was cloned; verify the prerequisites reject the empty-packingLoops case.
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
                                       packedTensor, offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the original op before re-seating `padTensorOp`.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}
906