1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/Transforms/LoopUtils.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-hoisting"
28 
29 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 using llvm::dbgs;
35 
36 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
37   bool changed = true;
38   while (changed) {
39     changed = false;
40     func.walk([&changed](Operation *op) {
41       if (!isa<AllocOp, AllocaOp, DeallocOp>(op))
42         return;
43 
44       LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
45       auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
46       LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
47 
48       // Only hoist out of immediately enclosing scf::ForOp.
49       if (!loop)
50         return;
51 
52       // If any operand is defined inside the loop don't hoist.
53       if (llvm::any_of(op->getOperands(), [&](Value v) {
54             return !loop.isDefinedOutsideOfLoop(v);
55           }))
56         return;
57 
58       LLVM_DEBUG(DBGS() << "All operands defined outside \n");
59 
60       // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
61       Value v;
62       if (op->getNumResults() > 0) {
63         assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
64         v = op->getResult(0);
65       }
66       if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
67             return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner());
68           })) {
69         LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
70         return;
71       }
72 
73       // Move AllocOp before the loop.
74       if (isa<AllocOp, AllocaOp>(op))
75         loop.moveOutOfLoop({op});
76       else // Move DeallocOp outside of the loop.
77         op->moveAfter(loop);
78       changed = true;
79     });
80   }
81 }
82 
83 /// Return true if we can prove that the transfer operations access dijoint
84 /// memory.
85 static bool isDisjoint(VectorTransferOpInterface transferA,
86                        VectorTransferOpInterface transferB) {
87   if (transferA.memref() != transferB.memref())
88     return false;
89   // For simplicity only look at transfer of same type.
90   if (transferA.getVectorType() != transferB.getVectorType())
91     return false;
92   unsigned rankOffset = transferA.getLeadingMemRefRank();
93   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
94     auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
95     auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
96     // If any of the indices are dynamic we cannot prove anything.
97     if (!indexA || !indexB)
98       continue;
99 
100     if (i < rankOffset) {
101       // For dimension used as index if we can prove that index are different we
102       // know we are accessing disjoint slices.
103       if (indexA.getValue().cast<IntegerAttr>().getInt() !=
104           indexB.getValue().cast<IntegerAttr>().getInt())
105         return true;
106     } else {
107       // For this dimension, we slice a part of the memref we need to make sure
108       // the intervals accessed don't overlap.
109       int64_t distance =
110           std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
111                    indexB.getValue().cast<IntegerAttr>().getInt());
112       if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
113         return true;
114     }
115   }
116   return false;
117 }
118 
119 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
120   bool changed = true;
121   while (changed) {
122     changed = false;
123 
124     func.walk([&](vector::TransferReadOp transferRead) {
125       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
126                         << *transferRead.getOperation() << "\n");
127       auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp());
128       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp()
129                         << "\n");
130       if (!loop)
131         return WalkResult::advance();
132 
133       if (failed(moveLoopInvariantCode(
134               cast<LoopLikeOpInterface>(loop.getOperation()))))
135         llvm_unreachable(
136             "Unexpected failure to move invariant code out of loop");
137 
138       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
139                         << "\n");
140 
141       llvm::SetVector<Operation *> forwardSlice;
142       getForwardSlice(transferRead, &forwardSlice);
143 
144       // Look for the last TransferWriteOp in the forwardSlice of
145       // `transferRead` that operates on the same memref.
146       vector::TransferWriteOp transferWrite;
147       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
148         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
149         if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
150           continue;
151         transferWrite = candidateWrite;
152       }
153 
154       // All operands of the TransferRead must be defined outside of the loop.
155       for (auto operand : transferRead.getOperands())
156         if (!loop.isDefinedOutsideOfLoop(operand))
157           return WalkResult::advance();
158 
159       // Only hoist transfer_read / transfer_write pairs for now.
160       if (!transferWrite)
161         return WalkResult::advance();
162 
163       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
164                         << "\n");
165 
166       // Approximate aliasing by checking that:
167       //   1. indices are the same,
168       //   2. no other operations in the loop access the same memref except
169       //      for transfer_read/transfer_write accessing statically disjoint
170       //      slices.
171       if (transferRead.indices() != transferWrite.indices() &&
172           transferRead.getVectorType() == transferWrite.getVectorType())
173         return WalkResult::advance();
174 
175       // TODO: may want to memoize this information for performance but it
176       // likely gets invalidated often.
177       DominanceInfo dom(loop);
178       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
179         return WalkResult::advance();
180       for (auto &use : transferRead.memref().getUses()) {
181         if (!dom.properlyDominates(loop, use.getOwner()))
182           continue;
183         if (use.getOwner() == transferRead.getOperation() ||
184             use.getOwner() == transferWrite.getOperation())
185           continue;
186         if (auto transferWriteUse =
187                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
188           if (!isDisjoint(
189                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
190                   cast<VectorTransferOpInterface>(
191                       transferWriteUse.getOperation())))
192             return WalkResult::advance();
193         } else if (auto transferReadUse =
194                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
195           if (!isDisjoint(
196                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
197                   cast<VectorTransferOpInterface>(
198                       transferReadUse.getOperation())))
199             return WalkResult::advance();
200         } else {
201           // Unknown use, we cannot prove that it doesn't alias with the
202           // transferRead/transferWrite operations.
203           return WalkResult::advance();
204         }
205       }
206 
207       // Hoist read before.
208       if (failed(loop.moveOutOfLoop({transferRead})))
209         llvm_unreachable(
210             "Unexpected failure to move transfer read out of loop");
211 
212       // Hoist write after.
213       transferWrite.getOperation()->moveAfter(loop);
214 
215       // Rewrite `loop` with new yields by cloning and erase the original loop.
216       OpBuilder b(transferRead);
217       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
218                                          transferWrite.vector());
219 
220       // Transfer write has been hoisted, need to update the written value to
221       // the value yielded by the newForOp.
222       transferWrite.vector().replaceAllUsesWith(
223           newForOp.getResults().take_back()[0]);
224 
225       changed = true;
226       loop.erase();
227       // Need to interrupt and restart because erasing the loop messes up the
228       // walk.
229       return WalkResult::interrupt();
230     });
231   }
232 }
233