1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/SCF/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/IR/Function.h"
23 #include "mlir/Transforms/LoopUtils.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Debug.h"
26 
27 #define DEBUG_TYPE "linalg-hoisting"
28 
29 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
30 
31 using namespace mlir;
32 using namespace mlir::linalg;
33 
34 using llvm::dbgs;
35 
36 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
37   bool changed = true;
38   while (changed) {
39     changed = false;
40     func.walk([&changed](Operation *op) {
41       if (!isa<AllocOp>(op) && !isa<AllocaOp>(op) && !isa<DeallocOp>(op))
42         return;
43 
44       LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n");
45       auto loop = dyn_cast<scf::ForOp>(op->getParentOp());
46       LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n");
47 
48       // Only hoist out of immediately enclosing scf::ForOp.
49       if (!loop)
50         return;
51 
52       // If any operand is defined inside the loop don't hoist.
53       if (llvm::any_of(op->getOperands(), [&](Value v) {
54             return !loop.isDefinedOutsideOfLoop(v);
55           }))
56         return;
57 
58       LLVM_DEBUG(DBGS() << "All operands defined outside \n");
59 
60       // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist.
61       Value v;
62       if (op->getNumResults() > 0) {
63         assert(op->getNumResults() == 1 && "Unexpected multi-result alloc");
64         v = op->getResult(0);
65       }
66       if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) {
67             return isa<ViewLikeOpInterface>(operand.getOwner()) ||
68                    isa<DeallocOp>(operand.getOwner());
69           })) {
70         LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n");
71         return;
72       }
73 
74       // Move AllocOp before the loop.
75       if (isa<AllocOp>(op) || isa<AllocaOp>(op))
76         loop.moveOutOfLoop({op});
77       else // Move DeallocOp outside of the loop.
78         op->moveAfter(loop);
79       changed = true;
80     });
81   }
82 }
83 
84 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
85   bool changed = true;
86   while (changed) {
87     changed = false;
88 
89     func.walk([&](vector::TransferReadOp transferRead) {
90       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
91                         << *transferRead.getOperation() << "\n");
92       auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp());
93       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp()
94                         << "\n");
95       if (!loop)
96         return WalkResult::advance();
97 
98       if (failed(moveLoopInvariantCode(
99               cast<LoopLikeOpInterface>(loop.getOperation()))))
100         llvm_unreachable(
101             "Unexpected failure to move invariant code out of loop");
102 
103       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
104                         << "\n");
105 
106       llvm::SetVector<Operation *> forwardSlice;
107       getForwardSlice(transferRead, &forwardSlice);
108 
109       // Look for the last TransferWriteOp in the forwardSlice of
110       // `transferRead` that operates on the same memref.
111       vector::TransferWriteOp transferWrite;
112       for (auto *sliceOp : llvm::reverse(forwardSlice)) {
113         auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
114         if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
115           continue;
116         transferWrite = candidateWrite;
117       }
118 
119       // All operands of the TransferRead must be defined outside of the loop.
120       for (auto operand : transferRead.getOperands())
121         if (!loop.isDefinedOutsideOfLoop(operand))
122           return WalkResult::advance();
123 
124       // Only hoist transfer_read / transfer_write pairs for now.
125       if (!transferWrite)
126         return WalkResult::advance();
127 
128       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
129                         << "\n");
130 
131       // Approximate aliasing by checking that:
132       //   1. indices are the same,
133       //   2. no other use either dominates the transfer_read or is dominated
134       //   by the transfer_write (i.e. aliasing between the write and the read
135       //   across the loop).
136       if (transferRead.indices() != transferWrite.indices())
137         return WalkResult::advance();
138 
139       // TODO: may want to memoize this information for performance but it
140       // likely gets invalidated often.
141       DominanceInfo dom(loop);
142       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
143         return WalkResult::advance();
144       for (auto &use : transferRead.memref().getUses())
145         if (dom.properlyDominates(use.getOwner(),
146                                   transferRead.getOperation()) ||
147             dom.properlyDominates(transferWrite, use.getOwner()))
148           return WalkResult::advance();
149 
150       // Hoist read before.
151       if (failed(loop.moveOutOfLoop({transferRead})))
152         llvm_unreachable(
153             "Unexpected failure to move transfer read out of loop");
154 
155       // Hoist write after.
156       transferWrite.getOperation()->moveAfter(loop);
157 
158       // Rewrite `loop` with new yields by cloning and erase the original loop.
159       OpBuilder b(transferRead);
160       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
161                                          transferWrite.vector());
162 
163       // Transfer write has been hoisted, need to update the written value to
164       // the value yielded by the newForOp.
165       transferWrite.vector().replaceAllUsesWith(
166           newForOp.getResults().take_back()[0]);
167 
168       changed = true;
169       loop.erase();
170       // Need to interrupt and restart because erasing the loop messes up the
171       // walk.
172       return WalkResult::interrupt();
173     });
174   }
175 }
176