1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/SCF/SCF.h"
23 #include "mlir/Dialect/SCF/Utils/Utils.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/Dominance.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/Support/Debug.h"
33 
34 using llvm::dbgs;
35 
36 #define DEBUG_TYPE "linalg-hoisting"
37 
38 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
39 
40 using namespace mlir;
41 using namespace mlir::linalg;
42 
43 namespace {
44 /// Represents a unit of hoistable TransferWriteOp. This may comprise other
45 /// instructions that need to be hoisted too.
46 struct HoistableWrite {
47   vector::TransferWriteOp transferWriteOp;
48   tensor::InsertSliceOp insertSliceOp;
49 };
50 /// Represents a unit of hoistable TransferReadOp. This may comprise other
51 /// instructions that need to be hoisted too.
52 struct HoistableRead {
53   vector::TransferReadOp transferReadOp;
54   tensor::ExtractSliceOp extractSliceOp;
55 };
56 } // namespace
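
// For illustration, assuming hypothetical value names, a HoistableWrite unit
// with an insert_slice looks like:
//
//   %w = vector.transfer_write %vec, %slice[%c0, %c0]
//       : vector<4x8xf32>, tensor<4x8xf32>
//   %i = tensor.insert_slice %w into %arg0[0, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> into tensor<16x16xf32>
//
// and the matching HoistableRead unit reads the same slice of the same
// tensor:
//
//   %slice = tensor.extract_slice %arg0[0, 0] [4, 8] [1, 1]
//       : tensor<16x16xf32> to tensor<4x8xf32>
//   %r = vector.transfer_read %slice[%c0, %c0], %pad
//       : tensor<4x8xf32>, vector<4x8xf32>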
57 
58 /// Return true if op1 and op2 are the same constant or the same SSA value.
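/// For example, an offset given by `%c4 = arith.constant 4 : index` compares
/// equal to the static attribute `4`, and two identical non-constant SSA
/// values compare equal to each other.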
59 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // If the OpFoldResult wraps a Value, look through a defining
    // arith.constant to recover the underlying attribute.
    if (!attr) {
      if (auto cstOp = ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
        attr = cstOp.getValue();
    }
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
69   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
70   if (cst1 && cst2 && *cst1 == *cst2)
71     return true;
72   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
73   return v1 && v2 && v1 == v2;
74 }
75 
/// Return true if all offsets, sizes and strides are equal.
77 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
78                                        tensor::InsertSliceOp si) {
79   if (s.static_offsets().size() != si.static_offsets().size())
80     return false;
81   if (s.static_sizes().size() != si.static_sizes().size())
82     return false;
83   if (s.static_strides().size() != si.static_strides().size())
84     return false;
85   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
86     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
87       return false;
88   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
89     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
90       return false;
91   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
92     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
93       return false;
94   return true;
95 }
96 
/// Look for a HoistableRead, among the uses of the given tensor, that
/// accesses the same offsets as the HoistableWrite.
99 static HoistableRead findMatchingTransferRead(HoistableWrite write,
100                                               Value srcTensor) {
101   assert(write.transferWriteOp &&
102          "expected hoistable write to have a .transfer_write");
103 
104   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
105                     << *write.transferWriteOp.getOperation() << "\n");
106   if (write.insertSliceOp)
107     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
108                       << *write.insertSliceOp.getOperation() << "\n");
109 
110   for (Operation *user : srcTensor.getUsers()) {
111     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
112                       << "\n");
113 
    // If HoistableWrite involves an InsertSliceOp, we need to find a
115     // matching ExtractSliceOp.
116     tensor::ExtractSliceOp sliceOp;
117     Operation *maybeTransferReadUser = user;
118     if (write.insertSliceOp) {
119       sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
120       if (!sliceOp || sliceOp.getResult().getType() !=
121                           write.insertSliceOp.source().getType())
122         continue;
123 
124       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
125                         << *sliceOp << " vs " << *write.insertSliceOp << "\n");
126       if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
127         continue;
128 
129       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
130       // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
131       //   1. the transfer_write we want to hoist.
132       //   2. a matching transfer_read.
133       // Anything else, we skip.
134       bool skip = false;
135       Operation *otherUser = nullptr;
136       for (Operation *u : sliceOp->getUsers()) {
137         if (u == write.transferWriteOp)
138           continue;
139         if (otherUser) {
140           skip = true;
141           break;
142         }
143         otherUser = u;
144       }
145       if (skip || !otherUser)
146         continue;
147       maybeTransferReadUser = otherUser;
148     }
149 
150     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
151                       << "\n");
152     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
153     if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
154         read.getVectorType() == write.transferWriteOp.getVectorType())
155       return HoistableRead{read, sliceOp};
156   }
157   return HoistableRead();
158 }
159 
/// Check if the chunk of data inserted by the HoistableWrite is read by any
/// op other than the HoistableRead candidate.
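/// The check walks a worklist of use ranges, following uses transitively
/// through vector.transfer_write results, scf.for region arguments and
/// scf.yield results; it conservatively returns true for any use it cannot
/// reason about (e.g. extract_slice / insert_slice).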
162 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
163                                            HoistableRead candidateRead,
164                                            BlockArgument tensorArg) {
165   // Make sure none of the other uses read the part of the tensor modified
166   // by the transfer_write.
167   llvm::SmallVector<Value::use_range, 1> uses;
168   uses.push_back(tensorArg.getUses());
169   while (!uses.empty()) {
170     for (OpOperand &use : uses.pop_back_val()) {
171       Operation *user = use.getOwner();
172       // Skip the candidate use, only inspect the "other" uses.
173       if (user == candidateRead.transferReadOp ||
174           user == candidateRead.extractSliceOp ||
175           user == write.transferWriteOp || user == write.insertSliceOp)
176         continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: at the moment we just bail because a stronger analysis is
      // needed for these cases.
180       if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
181         return true;
182       // Consider all transitive uses through a vector.transfer_write.
183       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
184         uses.push_back(writeUser->getResult(0).getUses());
185         continue;
186       }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left over from a previous level of
      // hoisting.
190       if (auto forUser = dyn_cast<scf::ForOp>(user)) {
191         Value arg = forUser.getLoopBody().getArgument(
192             use.getOperandNumber() - forUser.getNumControlOperands() +
193             /*iv value*/ 1);
194         uses.push_back(arg.getUses());
195         continue;
196       }
197       // Follow the use yield as long as it doesn't escape the original
198       // region.
199       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
200       if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
201                            yieldUser->getParentOp())) {
202         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
203         uses.push_back(ret.getUses());
204         continue;
205       }
206       auto read = dyn_cast<vector::TransferReadOp>(user);
207       if (!read || !vector::isDisjointTransferIndices(
208                        cast<VectorTransferOpInterface>(read.getOperation()),
209                        cast<VectorTransferOpInterface>(
210                            write.transferWriteOp.getOperation()))) {
211         return true;
212       }
213     }
214   }
215   return false;
216 }
217 
218 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return a null HoistableWrite() if it is not composed of a
/// vector.transfer_write + optional insert_slice, or if any of the indexings
/// is `forOp`-dependent.
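///
/// For illustration, with hypothetical value names and types, the yield
/// operand may be produced either directly:
///
///   %w = vector.transfer_write %v, %arg0[%c0] : vector<8xf32>, tensor<8xf32>
///   scf.yield %w : tensor<8xf32>
///
/// or with an intervening insert_slice into the matching BBArg:
///
///   %w = vector.transfer_write %v, %slice[%c0]
///       : vector<8xf32>, tensor<8xf32>
///   %i = tensor.insert_slice %w into %arg0[0] [8] [1]
///       : tensor<8xf32> into tensor<64xf32>
///   scf.yield %i : tensor<64xf32>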
222 static HoistableWrite
223 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
224                                         OpOperand &yieldOperand) {
225   Value v = yieldOperand.get();
226   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
227     // Indexing must not depend on `forOp`.
228     for (Value operand : write.getIndices())
229       if (!forOp.isDefinedOutsideOfLoop(operand))
230         return HoistableWrite();
231 
232     return HoistableWrite{write, nullptr};
233   }
234 
235   if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
236     // Inserted slice must come from vector.transfer_write.
237     auto write =
238         insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
239     if (!write)
240       return HoistableWrite();
241 
242     // Tensor inserted into must be a BBArg at position matching yieldOperand's.
243     auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
244     if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
245         bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
246       return HoistableWrite();
247 
248     // Indexing inserted into must not depend on `forOp`.
249     for (Value operand : insertSliceOp->getOperands().drop_front(
250              tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
251       if (!forOp.isDefinedOutsideOfLoop(operand))
252         return HoistableWrite();
253 
254     return HoistableWrite{write, insertSliceOp};
255   }
256 
257   return HoistableWrite();
258 }
259 
260 /// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
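/// The read (and its extract_slice, if present) is moved before `forOp`, the
/// write (and its insert_slice, if present) is moved after it, and the loop
/// is rewritten to carry the transferred vector through an additional
/// iter_args value.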
261 static void hoistReadWrite(HoistableRead read, HoistableWrite write,
262                            BlockArgument tensorBBArg) {
263   scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
264   assert(read.transferReadOp && write.transferWriteOp &&
265          "expected transfer_read and transfer_write ops to be set");
266   assert(((read.extractSliceOp && write.insertSliceOp) ||
267           (!read.extractSliceOp && !write.insertSliceOp)) &&
268          "expected matching extract_slice / insert_slice");
269   LLVM_DEBUG(DBGS() << "In forOp:\n"
270                     << *forOp.getOperation()
271                     << "\nHoist: " << *read.transferReadOp.getOperation()
272                     << "\nHoist: " << *write.transferWriteOp.getOperation()
273                     << "\nInvolving: " << tensorBBArg << "\n");
274 
275   // If a read slice is present, hoist it.
276   if (read.extractSliceOp)
277     forOp.moveOutOfLoop(read.extractSliceOp);
278 
279   // Hoist the transfer_read op.
280   forOp.moveOutOfLoop(read.transferReadOp);
281 
282   // TODO: don't hardcode /*numIvs=*/1.
283   assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
284   unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
285 
286   // Update the source tensor.
287   if (read.extractSliceOp)
288     read.extractSliceOp.sourceMutable().assign(
289         forOp.getInitArgs()[initArgNumber]);
290   else
291     read.transferReadOp.getSourceMutable().assign(
292         forOp.getInitArgs()[initArgNumber]);
293 
294   // Hoist write after.
295   if (write.insertSliceOp)
296     write.insertSliceOp->moveAfter(forOp);
297   write.transferWriteOp->moveAfter(forOp);
298 
299   // Update the yield.
300   auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
301   if (write.insertSliceOp)
302     yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
303   else
304     yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource());
305 
  // Rewrite `forOp` with an additional iter_args value initialized with the
  // hoisted read's vector and yielding the hoisted write's vector.
307   OpBuilder b(read.transferReadOp);
308   NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
309                                 ArrayRef<BlockArgument> newBBArgs) {
310     return SmallVector<Value>{write.transferWriteOp.getVector()};
311   };
312   auto newForOp = replaceLoopWithNewYields(
313       b, forOp, read.transferReadOp.getVector(), yieldFn);
314 
  // The transfer write has been hoisted, so update its vector and tensor
  // sources. Replace all uses of the loop result with the new tensor created
  // outside the loop.
  // Depending on whether an insert_slice is present, either the insert_slice
  // or the transfer_write carries the update of the tensor operands.
320   if (write.insertSliceOp) {
321     newForOp.getResult(initArgNumber)
322         .replaceAllUsesWith(write.insertSliceOp.getResult());
323     write.transferWriteOp.getSourceMutable().assign(
324         read.extractSliceOp.result());
325     write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
326   } else {
327     newForOp.getResult(initArgNumber)
328         .replaceAllUsesWith(write.transferWriteOp.getResult());
329     write.transferWriteOp.getSourceMutable().assign(
330         newForOp.getResult(initArgNumber));
331   }
332 
  // Always update the written vector with the value newly yielded by the loop.
334   write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
335 }
336 
// To hoist transfer ops on tensor, the logic can be significantly simplified
// compared to the case on buffer. The transformation follows these steps:
// 1. Look for a transfer_write whose single use is a ForOp yield.
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    transfer_read or transfer_write ops. For transfer_write uses, recurse to
//    make sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.for op may have unused arguments that can
// be removed by the canonicalization pass.
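//
// For illustration, assuming hypothetical value names and types, a loop of
// the form:
//
//   %1 = scf.for %i = %lb to %ub step %s iter_args(%arg0 = %t)
//       -> (tensor<8xf32>) {
//     %r = vector.transfer_read %arg0[%c0], %pad
//         : tensor<8xf32>, vector<8xf32>
//     %v = "some_use"(%r) : (vector<8xf32>) -> vector<8xf32>
//     %w = vector.transfer_write %v, %arg0[%c0]
//         : vector<8xf32>, tensor<8xf32>
//     scf.yield %w : tensor<8xf32>
//   }
//
// is expected to be rewritten into:
//
//   %r = vector.transfer_read %t[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %1:2 = scf.for %i = %lb to %ub step %s iter_args(%arg0 = %t, %arg1 = %r)
//       -> (tensor<8xf32>, vector<8xf32>) {
//     %v = "some_use"(%arg1) : (vector<8xf32>) -> vector<8xf32>
//     scf.yield %arg0, %v : tensor<8xf32>, vector<8xf32>
//   }
//   %w = vector.transfer_write %1#1, %1#0[%c0]
//       : vector<8xf32>, tensor<8xf32>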
348 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
349   bool changed = true;
350   while (changed) {
351     changed = false;
352     func.walk([&](scf::ForOp forOp) {
353       Operation *yield = forOp.getBody()->getTerminator();
354       for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
355         OpOperand &ret = yield->getOpOperand(it.index());
356         HoistableWrite write =
357             getLoopInvariantTransferWriteOpDefining(forOp, ret);
358         if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
359           continue;
360         LLVM_DEBUG(dbgs() << "\n";
361                    DBGS() << "Candidate write for hoisting: "
362                           << *write.transferWriteOp.getOperation() << "\n");
363         if (write.insertSliceOp)
364           LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
365                             << *write.insertSliceOp.getOperation() << "\n");
366         if (llvm::any_of(write.transferWriteOp.getIndices(),
367                          [&forOp](Value index) {
368                            return !forOp.isDefinedOutsideOfLoop(index);
369                          }))
370           continue;
371         // Find a read with the same type and indices.
372         HoistableRead matchingRead =
373             findMatchingTransferRead(write, it.value());
374         // Make sure none of the other uses read the part of the tensor modified
375         // by the transfer_write.
376         if (!matchingRead.transferReadOp ||
377             tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
378           continue;
379 
380         LLVM_DEBUG(DBGS() << "Start hoisting\n");
381         hoistReadWrite(matchingRead, write, it.value());
382         changed = true;
383         forOp.erase();
384 
385         // Need to interrupt and restart: erasing the loop messes up the walk.
386         return WalkResult::interrupt();
387       }
388       return WalkResult::advance();
389     });
390     // Apply canonicalization so the newForOp + yield folds immediately, thus
391     // cleaning up the IR and potentially enabling more hoisting.
392     if (changed) {
393       RewritePatternSet patterns(func->getContext());
394       scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
395       (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
396     }
397   }
398 }
399 
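/// Hoist vector.transfer_read / vector.transfer_write pairs on buffers out of
/// their immediately enclosing scf::ForOp, iteratively. Hoisting requires, in
/// particular, that the two transfers access the same memref with the same
/// indices, that all operands of the transfer_read are defined outside the
/// loop, and that no other access to the memref inside the loop may alias
/// with the pair, except transfers on statically disjoint slices.
///
/// For illustration, assuming hypothetical value names and types, a loop
/// such as:
///
///   scf.for %i = %lb to %ub step %s {
///     %r = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<8xf32>
///     %v = "some_use"(%r) : (vector<8xf32>) -> vector<8xf32>
///     vector.transfer_write %v, %m[%c0] : vector<8xf32>, memref<8xf32>
///   }
///
/// is expected to be rewritten into:
///
///   %r = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<8xf32>
///   %1 = scf.for %i = %lb to %ub step %s iter_args(%arg0 = %r)
///       -> (vector<8xf32>) {
///     %v = "some_use"(%arg0) : (vector<8xf32>) -> vector<8xf32>
///     scf.yield %v : vector<8xf32>
///   }
///   vector.transfer_write %1, %m[%c0] : vector<8xf32>, memref<8xf32>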
400 void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
401   bool changed = true;
402   while (changed) {
403     changed = false;
    // First move loop-invariant ops outside of their loop. This needs to be
    // done beforehand, as we cannot move ops without interrupting the
    // function walk.
406     func.walk(
407         [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
408 
409     func.walk([&](vector::TransferReadOp transferRead) {
410       if (!transferRead.getShapedType().isa<MemRefType>())
411         return WalkResult::advance();
412 
413       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
414                         << *transferRead.getOperation() << "\n");
415       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
416       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
417                         << "\n");
418       if (!loop)
419         return WalkResult::advance();
420 
421       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
422                         << "\n");
423 
424       SetVector<Operation *> forwardSlice;
425       getForwardSlice(transferRead.getOperation(), &forwardSlice);
426 
427       // Look for the last TransferWriteOp in the forwardSlice of
428       // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite ||
            candidateWrite.getSource() != transferRead.getSource())
          continue;
        // Take the first match in reverse order, i.e. the last matching
        // write in the forward slice.
        transferWrite = candidateWrite;
        break;
      }
437 
438       // All operands of the TransferRead must be defined outside of the loop.
439       for (auto operand : transferRead.getOperands())
440         if (!loop.isDefinedOutsideOfLoop(operand))
441           return WalkResult::advance();
442 
443       // Only hoist transfer_read / transfer_write pairs for now.
444       if (!transferWrite)
445         return WalkResult::advance();
446 
447       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
448                         << "\n");
449 
450       // Approximate aliasing by checking that:
451       //   1. indices are the same,
452       //   2. no other operations in the loop access the same memref except
453       //      for transfer_read/transfer_write accessing statically disjoint
454       //      slices.
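      // For instance (hypothetical names), a transfer_write of vector<8xf32>
      // at index %c0 and a transfer_read of vector<8xf32> at index %c8 of the
      // same memref access statically disjoint slices.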
      if (transferRead.getIndices() != transferWrite.getIndices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
457         return WalkResult::advance();
458 
459       // TODO: may want to memoize this information for performance but it
460       // likely gets invalidated often.
461       DominanceInfo dom(loop);
462       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
463         return WalkResult::advance();
464       for (auto &use : transferRead.getSource().getUses()) {
465         if (!loop->isAncestor(use.getOwner()))
466           continue;
467         if (use.getOwner() == transferRead.getOperation() ||
468             use.getOwner() == transferWrite.getOperation())
469           continue;
470         if (auto transferWriteUse =
471                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
472           if (!vector::isDisjointTransferSet(
473                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
474                   cast<VectorTransferOpInterface>(
475                       transferWriteUse.getOperation())))
476             return WalkResult::advance();
477         } else if (auto transferReadUse =
478                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
479           if (!vector::isDisjointTransferSet(
480                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
481                   cast<VectorTransferOpInterface>(
482                       transferReadUse.getOperation())))
483             return WalkResult::advance();
484         } else {
485           // Unknown use, we cannot prove that it doesn't alias with the
486           // transferRead/transferWrite operations.
487           return WalkResult::advance();
488         }
489       }
490 
491       // Hoist read before.
492       loop.moveOutOfLoop(transferRead);
493 
494       // Hoist write after.
495       transferWrite->moveAfter(loop);
496 
      // Rewrite `loop` with an additional yield carrying the hoisted write's
      // vector; the original loop is erased below.
498       OpBuilder b(transferRead);
499       NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
500                                     ArrayRef<BlockArgument> newBBArgs) {
501         return SmallVector<Value>{transferWrite.getVector()};
502       };
503       auto newForOp =
504           replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);
505 
      // The transfer write has been hoisted, so update the written vector
      // with the value yielded by the newForOp.
508       transferWrite.getVectorMutable().assign(newForOp.getResults().back());
509 
510       changed = true;
511       loop.erase();
512       // Need to interrupt and restart because erasing the loop messes up the
513       // walk.
514       return WalkResult::interrupt();
515     });
516   }
517 }
518