1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/Dialect/Vector/VectorUtils.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Dominance.h"
24 #include "mlir/Transforms/LoopUtils.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "linalg-hoisting"
29 
30 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 
35 using llvm::dbgs;
36 
37 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
38   bool changed = true;
39   while (changed) {
40     changed = false;
41     func.walk([&changed](Operation *op) {
42       if (!isa<AllocOp, AllocaOp, DeallocOp>(op))
43         return;
44 
45       LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
46       auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
47       LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
48 
49       // Only hoist out of immediately enclosing scf::ForOp.
50       if (!loop)
51         return;
52 
53       // If any operand is defined inside the loop don't hoist.
54       if (llvm::any_of(op->getOperands(), [&](Value v) {
55             return !loop.isDefinedOutsideOfLoop(v);
56           }))
57         return;
58 
59       LLVM_DEBUG(DBGS() << "All operands defined outside \n");
60 
61       // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
62       Value v;
63       if (op->getNumResults() > 0) {
64         assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
65         v = op->getResult(0);
66       }
67       if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
68             return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner());
69           })) {
70         LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
71         return;
72       }
73 
74       // Move AllocOp before the loop.
75       if (isa<AllocOp, AllocaOp>(op))
76         loop.moveOutOfLoop({op});
77       else // Move DeallocOp outside of the loop.
78         op->moveAfter(loop);
79       changed = true;
80     });
81   }
82 }
83 
84 /// Look for a transfer_read, in the given tensor uses, accessing the same
85 /// offset as the transfer_write.
86 static vector::TransferReadOp
87 findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
88   for (Operation *user : srcTensor.getUsers()) {
89     auto read = dyn_cast<vector::TransferReadOp>(user);
90     if (read && read.indices() == write.indices() &&
91         read.getVectorType() == write.getVectorType()) {
92       return read;
93     }
94   }
95   return nullptr;
96 }
97 
98 /// Check if the chunk of data inserted by the transfer_write in the given
99 /// tensor are read by any other op than the read candidate.
100 static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
101                                            vector::TransferReadOp candidateRead,
102                                            Value srcTensor) {
103   // Make sure none of the other uses read the part of the tensor modified
104   // by the transfer_write.
105   llvm::SmallVector<Value::use_range, 1> uses;
106   uses.push_back(srcTensor.getUses());
107   while (!uses.empty()) {
108     for (OpOperand &use : uses.pop_back_val()) {
109       Operation *user = use.getOwner();
110       // Skip the candidate use, only inspect the "other" uses.
111       if (user == candidateRead.getOperation() || user == write.getOperation())
112         continue;
113       // Consider all transitive uses through a vector.transfer_write.
114       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
115         uses.push_back(writeUser->getResult(0).getUses());
116         continue;
117       }
118       // Consider all nested uses through an scf::ForOp. We may have
119       // pass-through tensor arguments left from previous level of
120       // hoisting.
121       if (auto forUser = dyn_cast<scf::ForOp>(user)) {
122         Value arg = forUser.getLoopBody().getArgument(
123             use.getOperandNumber() - forUser.getNumControlOperands() +
124             /*iv value*/ 1);
125         uses.push_back(arg.getUses());
126         continue;
127       }
128       // Follow the use yield as long as it doesn't escape the original
129       // region.
130       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
131       if (yieldUser &&
132           write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
133         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
134         uses.push_back(ret.getUses());
135         continue;
136       }
137       auto read = dyn_cast<vector::TransferReadOp>(user);
138       if (!read || !isDisjointTransferIndices(
139                        cast<VectorTransferOpInterface>(read.getOperation()),
140                        cast<VectorTransferOpInterface>(write.getOperation()))) {
141         return true;
142       }
143     }
144   }
145   return false;
146 }
147 
148 // To hoist transfer op on tensor the logic can be significantly simplified
149 // compared to the case on buffer. The transformation follows this logic:
150 // 1. Look for transfer_write with a single use from ForOp yield
151 // 2. Check the uses of the matching block argument and look for a transfer_read
152 // with the same indices.
153 // 3. Check that all the other uses of the tensor argument are either disjoint
154 // tensor_read or transfer_write. For transfer_write uses recurse to make sure
155 // the new tensor has the same restrictions on its uses.
156 // 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
157 // After this transformation the scf.forOp may have unused arguments that can be
158 // remove by the canonicalization pass.
159 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
160   bool changed = true;
161   while (changed) {
162     changed = false;
163     func.walk([&](scf::ForOp forOp) {
164       Operation *yield = forOp.getBody()->getTerminator();
165       for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
166         Value ret = yield->getOperand(it.index());
167         auto write = ret.getDefiningOp<vector::TransferWriteOp>();
168         if (!write || !write->hasOneUse())
169           continue;
170         LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
171                           << *write.getOperation() << "\n");
172         if (llvm::any_of(write.indices(), [&forOp](Value index) {
173               return !forOp.isDefinedOutsideOfLoop(index);
174             }))
175           continue;
176         // Find a read with the same type and indices.
177         vector::TransferReadOp matchingRead =
178             findMatchingTransferRead(write, it.value());
179         // Make sure none of the other uses read the part of the tensor modified
180         // by the transfer_write.
181         if (!matchingRead ||
182             tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
183           continue;
184 
185         // Hoist read before.
186         if (failed(forOp.moveOutOfLoop({matchingRead})))
187           llvm_unreachable(
188               "Unexpected failure to move transfer read out of loop");
189         // Update the source tensor.
190         matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
191 
192         // Hoist write after.
193         write->moveAfter(forOp);
194         yield->setOperand(it.index(), write.source());
195 
196         // Rewrite `loop` with new yields by cloning and erase the original
197         // loop.
198         OpBuilder b(matchingRead);
199         auto newForOp =
200             cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
201 
202         // Transfer write has been hoisted, need to update the vector and tensor
203         // source. Replace the result of the loop to use the new tensor created
204         // outside the loop.
205         newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
206         write.vectorMutable().assign(newForOp.getResults().back());
207         write.sourceMutable().assign(newForOp.getResult(it.index()));
208 
209         changed = true;
210         forOp.erase();
211         // Need to interrupt and restart because erasing the loop messes up the
212         // walk.
213         return WalkResult::interrupt();
214       }
215       return WalkResult::advance();
216     });
217   }
218 }
219 
220 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
221   bool changed = true;
222   while (changed) {
223     changed = false;
224 
225     func.walk([&](vector::TransferReadOp transferRead) {
226       if (!transferRead.getShapedType().isa<MemRefType>())
227         return WalkResult::advance();
228 
229       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
230                         << *transferRead.getOperation() << "\n");
231       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
232       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
233                         << "\n");
234       if (!loop)
235         return WalkResult::advance();
236 
237       if (failed(moveLoopInvariantCode(
238               cast<LoopLikeOpInterface>(loop.getOperation()))))
239         llvm_unreachable(
240             "Unexpected failure to move invariant code out of loop");
241 
242       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
243                         << "\n");
244 
245       llvm::SetVector<Operation *> forwardSlice;
246       getForwardSlice(transferRead, &forwardSlice);
247 
248       // Look for the last TransferWriteOp in the forwardSlice of
249       // `transferRead` that operates on the same memref.
250       vector::TransferWriteOp transferWrite;
251       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
252         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
253         if (!candidateWrite || candidateWrite.source() != transferRead.source())
254           continue;
255         transferWrite = candidateWrite;
256       }
257 
258       // All operands of the TransferRead must be defined outside of the loop.
259       for (auto operand : transferRead.getOperands())
260         if (!loop.isDefinedOutsideOfLoop(operand))
261           return WalkResult::advance();
262 
263       // Only hoist transfer_read / transfer_write pairs for now.
264       if (!transferWrite)
265         return WalkResult::advance();
266 
267       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
268                         << "\n");
269 
270       // Approximate aliasing by checking that:
271       //   1. indices are the same,
272       //   2. no other operations in the loop access the same memref except
273       //      for transfer_read/transfer_write accessing statically disjoint
274       //      slices.
275       if (transferRead.indices() != transferWrite.indices() &&
276           transferRead.getVectorType() == transferWrite.getVectorType())
277         return WalkResult::advance();
278 
279       // TODO: may want to memoize this information for performance but it
280       // likely gets invalidated often.
281       DominanceInfo dom(loop);
282       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
283         return WalkResult::advance();
284       for (auto &use : transferRead.source().getUses()) {
285         if (!dom.properlyDominates(loop, use.getOwner()))
286           continue;
287         if (use.getOwner() == transferRead.getOperation() ||
288             use.getOwner() == transferWrite.getOperation())
289           continue;
290         if (auto transferWriteUse =
291                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
292           if (!isDisjointTransferSet(
293                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
294                   cast<VectorTransferOpInterface>(
295                       transferWriteUse.getOperation())))
296             return WalkResult::advance();
297         } else if (auto transferReadUse =
298                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
299           if (!isDisjointTransferSet(
300                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
301                   cast<VectorTransferOpInterface>(
302                       transferReadUse.getOperation())))
303             return WalkResult::advance();
304         } else {
305           // Unknown use, we cannot prove that it doesn't alias with the
306           // transferRead/transferWrite operations.
307           return WalkResult::advance();
308         }
309       }
310 
311       // Hoist read before.
312       if (failed(loop.moveOutOfLoop({transferRead})))
313         llvm_unreachable(
314             "Unexpected failure to move transfer read out of loop");
315 
316       // Hoist write after.
317       transferWrite->moveAfter(loop);
318 
319       // Rewrite `loop` with new yields by cloning and erase the original loop.
320       OpBuilder b(transferRead);
321       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
322                                          transferWrite.vector());
323 
324       // Transfer write has been hoisted, need to update the written value to
325       // the value yielded by the newForOp.
326       transferWrite.vector().replaceAllUsesWith(
327           newForOp.getResults().take_back()[0]);
328 
329       changed = true;
330       loop.erase();
331       // Need to interrupt and restart because erasing the loop messes up the
332       // walk.
333       return WalkResult::interrupt();
334     });
335   }
336 }
337