//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// operations that need to be hoisted too.
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// operations that need to be hoisted too.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace
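
// For illustration, a matching HoistableWrite / HoistableRead pair on tensors
// typically looks as follows inside a loop (all SSA names are made up):
//
//   %st = tensor.extract_slice %arg[%off][%sz][1] : tensor<...> to tensor<...>
//   %r = vector.transfer_read %st[%c0], %pad : tensor<...>, vector<...>
//   %v = "some_compute"(%r) : (vector<...>) -> vector<...>
//   %w = vector.transfer_write %v, %st[%c0] : vector<...>, tensor<...>
//   %res = tensor.insert_slice %w into %arg[%off][%sz][1] : tensor<...> into ...
//
// The extract_slice / insert_slice pair is optional; when present, both ops
// must use the same offsets, sizes and strides.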

/// Return true if op1 and op2 are the same constant or the same SSA value.
/// For example, an IntegerAttr `0` and a Value defined by `arith.constant 0`
/// compare equal.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}
/// Return true if all offsets, sizes and strides are equal.
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  if (s.static_offsets().size() != si.static_offsets().size())
    return false;
  if (s.static_sizes().size() != si.static_sizes().size())
    return false;
  if (s.static_strides().size() != si.static_strides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, among the uses of `srcTensor`, that accesses the
/// same slice (i.e. the same offsets, sizes and strides) as the
/// HoistableWrite.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If the HoistableWrite involves an InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}

/// Return true if the chunk of data written by the HoistableWrite is read by
/// any op other than the HoistableRead candidate.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: we currently just bail out because a stronger analysis is
      // needed for these cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left over from previous levels of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow a use through an scf.yield as long as it doesn't escape the
      // region of the original write.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !vector::isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}
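
// For illustration (SSA names are made up), the analysis above bails out on a
// use such as:
//
//   %other = vector.transfer_read %bbArg[%k], %pad : tensor<...>, vector<...>
//
// when %k cannot be proven statically disjoint from the indices of the
// candidate transfer_write, and succeeds when all other transfer accesses are
// provably disjoint from the written chunk.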

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return a null HoistableWrite() if the operand is not produced by a
/// vector.transfer_write + optional insert_slice, or if any of the indexings
/// depends on `forOp`.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}
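
// For illustration (SSA names are made up), the two forms recognized above
// are:
//
//   scf.yield %w : tensor<...>
//     with %w = vector.transfer_write %v, %bbArg[%off] ...
//
// and
//
//   scf.yield %res : tensor<...>
//     with %res = tensor.insert_slice %w into %bbArg[%off][%sz][1] ...
//     and  %w = vector.transfer_write %v, %st[%c0] ...
//
// where %bbArg is the iter_arg at the position matching the yield operand and
// all indices, offsets, sizes and strides are defined outside of `forOp`.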

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);

  // Hoist write after.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield.
  auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `forOp` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // The transfer_write has been hoisted: update its vector and tensor
  // operands, and replace all uses of the loop result with the new tensor
  // created outside the loop.
  // Depending on whether an insert_slice is present or not, the tensor
  // operand updates go through the insert_slice or directly through the
  // transfer_write.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult());
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update the hoisted transfer_write with the newly yielded vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer ops on tensors, the logic can be significantly simplified
// compared to the case on buffers. The transformation follows this logic:
// 1. Look for a transfer_write whose single use is an operand of the ForOp
//    yield.
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    transfer_read or transfer_write ops. For transfer_write uses, recurse to
//    make sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read / transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.for op may have unused arguments that can
// be removed by the canonicalization pass.
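//
// For illustration only (SSA names are made up), the rewrite turns:
//
//   %1 = scf.for %i = %lb to %ub step %s iter_args(%arg = %t) -> (tensor<...>) {
//     %r = vector.transfer_read %arg[%c0], %pad : tensor<...>, vector<...>
//     %v = "some_compute"(%r) : (vector<...>) -> vector<...>
//     %w = vector.transfer_write %v, %arg[%c0] : vector<...>, tensor<...>
//     scf.yield %w : tensor<...>
//   }
//
// into:
//
//   %r = vector.transfer_read %t[%c0], %pad : tensor<...>, vector<...>
//   %1:2 = scf.for %i = %lb to %ub step %s iter_args(%arg = %t, %varg = %r)
//       -> (tensor<...>, vector<...>) {
//     %v = "some_compute"(%varg) : (vector<...>) -> vector<...>
//     scf.yield %arg, %v : tensor<...>, vector<...>
//   }
//   %w = vector.transfer_write %1#1, %1#0[%c0] : vector<...>, tensor<...>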
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}

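// Hoist loop-invariant vector.transfer_read / vector.transfer_write pairs on
// memrefs out of their immediately enclosing scf.for, when no other access to
// the same memref inside the loop may alias the transferred slice.
//
// For illustration only (SSA names are made up), the rewrite turns:
//
//   scf.for %i = %lb to %ub step %s {
//     %r = vector.transfer_read %m[%c0], %pad : memref<...>, vector<...>
//     %v = "some_compute"(%r) : (vector<...>) -> vector<...>
//     vector.transfer_write %v, %m[%c0] : vector<...>, memref<...>
//   }
//
// into:
//
//   %r = vector.transfer_read %m[%c0], %pad : memref<...>, vector<...>
//   %1 = scf.for %i = %lb to %ub step %s iter_args(%varg = %r) -> (vector<...>) {
//     %v = "some_compute"(%varg) : (vector<...>) -> vector<...>
//     scf.yield %v : vector<...>
//   }
//   vector.transfer_write %1, %m[%c0] : vector<...>, memref<...>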
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop-invariant ops outside of their loop. This needs to be
    // done before the walk below, as we cannot move ops while walking the
    // function.
    func.walk([&](LoopLikeOpInterface loopLike) {
      if (failed(moveLoopInvariantCode(loopLike)))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");
    });

    func.walk([&](vector::TransferReadOp transferRead) {
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);

      // Look for the last TransferWriteOp in the forwardSlice of
      // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite || candidateWrite.source() != transferRead.source())
          continue;
        transferWrite = candidateWrite;
      }

      // All operands of the TransferRead must be defined outside of the loop.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Approximate aliasing by checking that:
      //   1. indices are the same,
      //   2. no other operations in the loop access the same memref except
      //      for transfer_read/transfer_write accessing statically disjoint
      //      slices.
      if (transferRead.indices() != transferWrite.indices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();
      for (auto &use : transferRead.source().getUses()) {
        if (!loop->isAncestor(use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use, we cannot prove that it doesn't alias with the
          // transferRead/transferWrite operations.
          return WalkResult::advance();
        }
      }

      // Hoist read before.
      if (failed(loop.moveOutOfLoop({transferRead})))
        llvm_unreachable(
            "Unexpected failure to move transfer read out of loop");

      // Hoist write after.
      transferWrite->moveAfter(loop);

      // Rewrite `loop` with an additional yield by cloning it; the original
      // loop is erased below once its uses have been rewired.
      OpBuilder b(transferRead);
      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
                                         transferWrite.vector());

      // The transfer_write has been hoisted: update the written value to the
      // value yielded by the newForOp.
      transferWrite.vector().replaceAllUsesWith(
          newForOp.getResults().take_back()[0]);

      changed = true;
      loop.erase();
      // Need to interrupt and restart because erasing the loop messes up the
      // walk.
      return WalkResult::interrupt();
    });
  }
}