//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// operations that need to be hoisted too.
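/// For example (illustrative IR, with `%c0`, `%a` and `%b` loop-invariant),
/// the following pair forms one hoistable unit:
///   %w = vector.transfer_write %vec, %slice[%c0, %c0]
///       : vector<4x4xf32>, tensor<4x4xf32>
///   %r = tensor.insert_slice %w into %t[%a, %b] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<?x?xf32>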
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// operations that need to be hoisted too.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace

/// Return true if op1 and op2 are the same constant or the same SSA value.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
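  // First compare as constants (e.g. an IntegerAttr `4` and the result of
  // `arith.constant 4 : index` compare equal), then fall back to strict SSA
  // value equality.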
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}

/// Return true if all offsets, sizes and strides are equal.
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  if (s.getStaticOffsets().size() != si.getStaticOffsets().size())
    return false;
  if (s.getStaticSizes().size() != si.getStaticSizes().size())
    return false;
  if (s.getStaticStrides().size() != si.getStaticStrides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, in the given tensor uses, accessing the same
/// offset as the HoistableWrite.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If the HoistableWrite involves an InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.getSource().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite is read by any
/// op other than the HoistableRead candidate.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
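  // Worklist-style traversal: start from the direct uses of the loop bbarg
  // and append transitive uses through the ops we understand
  // (vector.transfer_write, scf.for, scf.yield) as they are encountered.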
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: at the moment we just bail because a stronger analysis is
      // needed for these cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous levels of hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use through a yield as long as it doesn't escape the
      // original region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !vector::isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.getIndices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // The inserted slice must come from a vector.transfer_write.
    auto write =
        insertSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // The tensor inserted into must be a BBArg at the position matching
    // yieldOperand's.
    auto bbArg = insertSliceOp.getDest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // The indexing inserted into must not depend on `forOp`.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}
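// For example (illustrative IR, with all offsets and indices loop-invariant
// and `%arg1` the iter_arg tied to the yield operand), the following is
// matched as a HoistableWrite with an insert_slice:
//   %w = vector.transfer_write %v, %s[%c0, %c0]
//       : vector<4x4xf32>, tensor<4x4xf32>
//   %r = tensor.insert_slice %w into %arg1[%a, %b] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<?x?xf32>
//   scf.yield %r : tensor<?x?xf32>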

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp)
    forOp.moveOutOfLoop(read.extractSliceOp);

  // Hoist the transfer_read op.
  forOp.moveOutOfLoop(read.transferReadOp);

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor.
  if (read.extractSliceOp)
    read.extractSliceOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
  else
    read.transferReadOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);

  // Hoist the write after the loop.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield.
  auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());

  // Rewrite `loop` with an additional new yield.
  OpBuilder b(read.transferReadOp);
  NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
                                ArrayRef<BlockArgument> newBBArgs) {
    return SmallVector<Value>{write.transferWriteOp.getVector()};
  };
  auto newForOp = replaceLoopWithNewYields(
      b, forOp, read.transferReadOp.getVector(), yieldFn);
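  // `replaceLoopWithNewYields` rebuilds the loop with an extra iter_arg
  // seeded by the hoisted read's vector and an extra yield carrying the
  // written vector; the now-dead original loop is erased by the caller once
  // its uses have been rewired below.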

  // The transfer write has been hoisted, so update the vector and tensor
  // sources. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether an insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.getSourceMutable().assign(
        read.extractSliceOp.getResult());
    write.insertSliceOp.getDestMutable().assign(
        read.extractSliceOp.getSource());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult());
    write.transferWriteOp.getSourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yielded tensor and vector.
  write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer ops on tensors, the logic can be significantly simplified
// compared to the case on buffers. The transformation follows this logic:
// 1. Look for a transfer_write with a single use coming from the ForOp yield.
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    transfer_read or transfer_write ops. For transfer_write uses, recurse to
//    make sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.ForOp may have unused arguments that can
// be removed by the canonicalization pass.
// An illustrative before/after sketch follows this comment.
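//
// For example (illustrative IR, simplified to a 1-D case):
//
//   %r = scf.for ... iter_args(%t = %init) -> (tensor<?xf32>) {
//     %v = vector.transfer_read %t[%c0], %pad
//         : tensor<?xf32>, vector<4xf32>
//     %v2 = "compute"(%v) : (vector<4xf32>) -> vector<4xf32>
//     %w = vector.transfer_write %v2, %t[%c0]
//         : vector<4xf32>, tensor<?xf32>
//     scf.yield %w : tensor<?xf32>
//   }
//
// hoists to:
//
//   %v = vector.transfer_read %init[%c0], %pad
//       : tensor<?xf32>, vector<4xf32>
//   %r:2 = scf.for ... iter_args(%t = %init, %vi = %v)
//       -> (tensor<?xf32>, vector<4xf32>) {
//     %v2 = "compute"(%vi) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %t, %v2 : tensor<?xf32>, vector<4xf32>
//   }
//   %w = vector.transfer_write %r#1, %r#0[%c0]
//       : vector<4xf32>, tensor<?xf32>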
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        if (llvm::any_of(write.transferWriteOp.getIndices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor
        // modified by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}

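/// Hoist matching vector.transfer_read / vector.transfer_write pairs on
/// memrefs out of their immediately enclosing scf.for, iterating over the
/// function until a fixed point is reached. Typical usage from inside a pass
/// (illustrative sketch):
///   linalg::hoistRedundantVectorTransfers(getOperation());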
void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop-invariant ops outside of their loop. This needs to be
    // done beforehand, as we cannot move ops during the transfer walk below
    // without invalidating the walk.
    func.walk(
        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });

    func.walk([&](vector::TransferReadOp transferRead) {
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);

      // Look for the last TransferWriteOp in the forwardSlice of
      // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite ||
            candidateWrite.getSource() != transferRead.getSource())
          continue;
        transferWrite = candidateWrite;
        break;
      }
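      // Note: combined with the disjointness checks below, keeping the last
      // such write ensures that the value materialized after the loop is the
      // final one written to this location.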

      // All operands of the TransferRead must be defined outside of the loop.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Approximate aliasing by checking that:
      //   1. indices and vector type are the same,
      //   2. no other operations in the loop access the same memref except
      //      for transfer_read/transfer_write accessing statically disjoint
      //      slices.
      if (transferRead.getIndices() != transferWrite.getIndices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();
      for (auto &use : transferRead.getSource().getUses()) {
        if (!loop->isAncestor(use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use, we cannot prove that it doesn't alias with the
          // transferRead/transferWrite operations.
          return WalkResult::advance();
        }
      }

      // Hoist the read before the loop.
      loop.moveOutOfLoop(transferRead);

      // Hoist the write after the loop.
      transferWrite->moveAfter(loop);

      // Rewrite `loop` with a new yield by cloning it, then erase the
      // original loop.
      OpBuilder b(transferRead);
      NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
                                    ArrayRef<BlockArgument> newBBArgs) {
        return SmallVector<Value>{transferWrite.getVector()};
      };
      auto newForOp =
          replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);

      // The transfer write has been hoisted, so update the written vector to
      // the value yielded by the newForOp.
      transferWrite.getVectorMutable().assign(newForOp.getResults().back());

      changed = true;
      loop.erase();
      // Need to interrupt and restart because erasing the loop messes up the
      // walk.
      return WalkResult::interrupt();
    });
  }
}