1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functions concerned with hoisting invariant operations
10 // in the context of Linalg transformations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/SCF/Utils/Utils.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26 #include "mlir/IR/BuiltinOps.h"
27 #include "mlir/IR/Dominance.h"
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Support/Debug.h"
31 
32 using llvm::dbgs;
33 
34 #define DEBUG_TYPE "linalg-hoisting"
35 
36 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
37 
38 using namespace mlir;
39 using namespace mlir::linalg;
40 
41 namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// operations that need to be hoisted too, e.g. a consuming
/// tensor.insert_slice.
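/// For illustration (shapes and names are illustrative only), such a unit
/// matches IR of the form:
///   %w = vector.transfer_write %vec, %slice[%c0, %c0]
///       : vector<4x4xf32>, tensor<4x4xf32>
///   %i = tensor.insert_slice %w into %iterArg[%a, %b] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<128x128xf32>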
44 struct HoistableWrite {
45   vector::TransferWriteOp transferWriteOp;
46   tensor::InsertSliceOp insertSliceOp;
47 };
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// operations that need to be hoisted too, e.g. a producing
/// tensor.extract_slice.
50 struct HoistableRead {
51   vector::TransferReadOp transferReadOp;
52   tensor::ExtractSliceOp extractSliceOp;
53 };
54 } // namespace
55 
56 /// Return true if op1 and op2 are the same constant or the same SSA value.
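/// For example, an IntegerAttr `4` and a Value defined by
/// `arith.constant 4 : index` compare equal under this predicate.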
57 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
58   auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
59     Attribute attr = ofr.dyn_cast<Attribute>();
60     // Note: isa+cast-like pattern allows writing the condition below as 1 line.
61     if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
62       attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
63     if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
64       return intAttr.getValue().getSExtValue();
65     return llvm::None;
66   };
67   auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
68   if (cst1 && cst2 && *cst1 == *cst2)
69     return true;
70   auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
71   return v1 && v2 && v1 == v2;
72 }
73 
/// Return true if all offsets, sizes and strides are equal.
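/// For illustration (illustrative values), the following pair has matching
/// offsets, sizes and strides:
///   %e = tensor.extract_slice %t[%i, 0] [4, 8] [1, 1]
///       : tensor<128x128xf32> to tensor<4x8xf32>
///   %w = tensor.insert_slice %v into %t[%i, 0] [4, 8] [1, 1]
///       : tensor<4x8xf32> into tensor<128x128xf32>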
75 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
76                                        tensor::InsertSliceOp si) {
77   if (s.static_offsets().size() != si.static_offsets().size())
78     return false;
79   if (s.static_sizes().size() != si.static_sizes().size())
80     return false;
81   if (s.static_strides().size() != si.static_strides().size())
82     return false;
83   for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
84     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
85       return false;
86   for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
87     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
88       return false;
89   for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
90     if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
91       return false;
92   return true;
93 }
94 
/// Look for a HoistableRead, among the uses of `srcTensor`, that accesses the
/// same offsets, sizes and strides as the HoistableWrite `write`.
97 static HoistableRead findMatchingTransferRead(HoistableWrite write,
98                                               Value srcTensor) {
99   assert(write.transferWriteOp &&
100          "expected hoistable write to have a .transfer_write");
101 
102   LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
103                     << *write.transferWriteOp.getOperation() << "\n");
104   if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
106                       << *write.insertSliceOp.getOperation() << "\n");
107 
108   for (Operation *user : srcTensor.getUsers()) {
109     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
110                       << "\n");
111 
    // If the HoistableWrite involves an InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
114     tensor::ExtractSliceOp sliceOp;
115     Operation *maybeTransferReadUser = user;
116     if (write.insertSliceOp) {
117       sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
118       if (!sliceOp || sliceOp.getResult().getType() !=
119                           write.insertSliceOp.source().getType())
120         continue;
121 
122       LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
123                         << *sliceOp << " vs " << *write.insertSliceOp << "\n");
124       if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
125         continue;
126 
127       LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
128       // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
129       //   1. the transfer_write we want to hoist.
130       //   2. a matching transfer_read.
131       // Anything else, we skip.
132       bool skip = false;
133       Operation *otherUser = nullptr;
134       for (Operation *u : sliceOp->getUsers()) {
135         if (u == write.transferWriteOp)
136           continue;
137         if (otherUser) {
138           skip = true;
139           break;
140         }
141         otherUser = u;
142       }
143       if (skip || !otherUser)
144         continue;
145       maybeTransferReadUser = otherUser;
146     }
147 
148     LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
149                       << "\n");
150     auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
151     if (read && read.indices() == write.transferWriteOp.indices() &&
152         read.getVectorType() == write.transferWriteOp.getVectorType())
153       return HoistableRead{read, sliceOp};
154   }
155   return HoistableRead();
156 }
157 
/// Check whether the chunk of data inserted by the HoistableWrite is read by
/// any op other than the HoistableRead candidate.
160 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
161                                            HoistableRead candidateRead,
162                                            BlockArgument tensorArg) {
163   // Make sure none of the other uses read the part of the tensor modified
164   // by the transfer_write.
165   llvm::SmallVector<Value::use_range, 1> uses;
166   uses.push_back(tensorArg.getUses());
167   while (!uses.empty()) {
168     for (OpOperand &use : uses.pop_back_val()) {
169       Operation *user = use.getOwner();
170       // Skip the candidate use, only inspect the "other" uses.
171       if (user == candidateRead.transferReadOp ||
172           user == candidateRead.extractSliceOp ||
173           user == write.transferWriteOp || user == write.insertSliceOp)
174         continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: for now we just bail because a stronger analysis is needed for
      // these cases.
178       if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
179         return true;
180       // Consider all transitive uses through a vector.transfer_write.
181       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
182         uses.push_back(writeUser->getResult(0).getUses());
183         continue;
184       }
185       // Consider all nested uses through an scf::ForOp. We may have
186       // pass-through tensor arguments left from previous level of
187       // hoisting.
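      // The for op's operands are [lb, ub, step, initArgs...] while its
      // region arguments are [iv, iterArgs...]; hence the shift by
      // getNumControlOperands() and the +1 for the induction variable below.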
188       if (auto forUser = dyn_cast<scf::ForOp>(user)) {
189         Value arg = forUser.getLoopBody().getArgument(
190             use.getOperandNumber() - forUser.getNumControlOperands() +
191             /*iv value*/ 1);
192         uses.push_back(arg.getUses());
193         continue;
194       }
195       // Follow the use yield as long as it doesn't escape the original
196       // region.
197       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
198       if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
199                            yieldUser->getParentOp())) {
200         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
201         uses.push_back(ret.getUses());
202         continue;
203       }
204       auto read = dyn_cast<vector::TransferReadOp>(user);
205       if (!read || !vector::isDisjointTransferIndices(
206                        cast<VectorTransferOpInterface>(read.getOperation()),
207                        cast<VectorTransferOpInterface>(
208                            write.transferWriteOp.getOperation()))) {
209         return true;
210       }
211     }
212   }
213   return false;
214 }
215 
/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return a null HoistableWrite if the yielded value is not produced by a
/// vector.transfer_write + optional insert_slice, or if any of the indexings
/// depends on `forOp`.
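/// For illustration (names and shapes are illustrative only), the
/// insert_slice form matches:
///   scf.for ... iter_args(%arg = %init) -> (tensor<128x128xf32>) {
///     %w = vector.transfer_write %v, %s[%c0, %c0]
///         : vector<4x4xf32>, tensor<4x4xf32>
///     %i = tensor.insert_slice %w into %arg[%a, %b] [4, 4] [1, 1]
///         : tensor<4x4xf32> into tensor<128x128xf32>
///     scf.yield %i : tensor<128x128xf32>
///   }
/// where neither %a, %b nor the write indices depend on the loop.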
220 static HoistableWrite
221 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
222                                         OpOperand &yieldOperand) {
223   Value v = yieldOperand.get();
224   if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
225     // Indexing must not depend on `forOp`.
226     for (Value operand : write.indices())
227       if (!forOp.isDefinedOutsideOfLoop(operand))
228         return HoistableWrite();
229 
230     return HoistableWrite{write, nullptr};
231   }
232 
233   if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
234     // Inserted slice must come from vector.transfer_write.
235     auto write =
236         insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
237     if (!write)
238       return HoistableWrite();
239 
240     // Tensor inserted into must be a BBArg at position matching yieldOperand's.
241     auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
242     if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
243         bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
244       return HoistableWrite();
245 
    // The insertion indices must not depend on `forOp`.
247     for (Value operand : insertSliceOp->getOperands().drop_front(
248              tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
249       if (!forOp.isDefinedOutsideOfLoop(operand))
250         return HoistableWrite();
251 
252     return HoistableWrite{write, insertSliceOp};
253   }
254 
255   return HoistableWrite();
256 }
257 
258 /// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
259 static void hoistReadWrite(HoistableRead read, HoistableWrite write,
260                            BlockArgument tensorBBArg) {
261   scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
262   assert(read.transferReadOp && write.transferWriteOp &&
263          "expected transfer_read and transfer_write ops to be set");
264   assert(((read.extractSliceOp && write.insertSliceOp) ||
265           (!read.extractSliceOp && !write.insertSliceOp)) &&
266          "expected matching extract_slice / insert_slice");
267   LLVM_DEBUG(DBGS() << "In forOp:\n"
268                     << *forOp.getOperation()
269                     << "\nHoist: " << *read.transferReadOp.getOperation()
270                     << "\nHoist: " << *write.transferWriteOp.getOperation()
271                     << "\nInvolving: " << tensorBBArg << "\n");
272 
273   // If a read slice is present, hoist it.
274   if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
275     llvm_unreachable("Unexpected failure moving extract_slice out of loop");
276 
277   // Hoist the transfer_read op.
278   if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
279     llvm_unreachable("Unexpected failure moving transfer read out of loop");
280 
281   // TODO: don't hardcode /*numIvs=*/1.
282   assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
283   unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
284 
285   // Update the source tensor.
286   if (read.extractSliceOp)
287     read.extractSliceOp.sourceMutable().assign(
288         forOp.getInitArgs()[initArgNumber]);
289   else
290     read.transferReadOp.sourceMutable().assign(
291         forOp.getInitArgs()[initArgNumber]);
292 
293   // Hoist write after.
294   if (write.insertSliceOp)
295     write.insertSliceOp->moveAfter(forOp);
296   write.transferWriteOp->moveAfter(forOp);
297 
298   // Update the yield.
299   auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
300   if (write.insertSliceOp)
301     yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
302   else
303     yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());
304 
  // Rewrite `forOp` with additional new yields.
306   OpBuilder b(read.transferReadOp);
307   auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
308                                      write.transferWriteOp.vector());
  // The transfer_write has been hoisted; its vector and tensor sources must be
  // updated. Replace the result of the loop with the new tensor created
  // outside of the loop.
  // Depending on whether an insert_slice is present, either the insert_slice
  // or the transfer_write itself carries the tensor operand updates.
314   if (write.insertSliceOp) {
315     newForOp.getResult(initArgNumber)
316         .replaceAllUsesWith(write.insertSliceOp.getResult());
317     write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
318     write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
319   } else {
320     newForOp.getResult(initArgNumber)
321         .replaceAllUsesWith(write.transferWriteOp.getResult());
322     write.transferWriteOp.sourceMutable().assign(
323         newForOp.getResult(initArgNumber));
324   }
325 
  // Always update the vector operand with the newly yielded vector.
327   write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
328 }
329 
// To hoist transfer ops on tensors, the logic can be significantly simplified
// compared to the case on buffers. The transformation follows this logic:
// 1. Look for a transfer_write with a single use that is a ForOp yield.
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    transfer_reads or transfer_writes. For transfer_write uses, recurse to
//    make sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.for op may have unused arguments that can
// be removed by the canonicalization pass.
// A before/after sketch of the rewrite is given below.
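//
// For illustration (shapes and names are illustrative only), the rewrite
// turns:
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<128xf32>) {
//     %read = vector.transfer_read %t[%c0], %pad
//         : tensor<128xf32>, vector<4xf32>
//     %v = "some_use"(%read) : (vector<4xf32>) -> vector<4xf32>
//     %write = vector.transfer_write %v, %t[%c0]
//         : vector<4xf32>, tensor<128xf32>
//     scf.yield %write : tensor<128xf32>
//   }
//
// into:
//
//   %read = vector.transfer_read %init[%c0], %pad
//       : tensor<128xf32>, vector<4xf32>
//   %f:2 = scf.for %i = %lb to %ub step %s
//       iter_args(%t = %init, %v = %read)
//       -> (tensor<128xf32>, vector<4xf32>) {
//     %v2 = "some_use"(%v) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %t, %v2 : tensor<128xf32>, vector<4xf32>
//   }
//   %write = vector.transfer_write %f#1, %f#0[%c0]
//       : vector<4xf32>, tensor<128xf32>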
341 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
342   bool changed = true;
343   while (changed) {
344     changed = false;
345     func.walk([&](scf::ForOp forOp) {
346       Operation *yield = forOp.getBody()->getTerminator();
347       for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
348         OpOperand &ret = yield->getOpOperand(it.index());
349         HoistableWrite write =
350             getLoopInvariantTransferWriteOpDefining(forOp, ret);
351         if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
352           continue;
353         LLVM_DEBUG(dbgs() << "\n";
354                    DBGS() << "Candidate write for hoisting: "
355                           << *write.transferWriteOp.getOperation() << "\n");
356         if (write.insertSliceOp)
357           LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
358                             << *write.insertSliceOp.getOperation() << "\n");
359         if (llvm::any_of(write.transferWriteOp.indices(),
360                          [&forOp](Value index) {
361                            return !forOp.isDefinedOutsideOfLoop(index);
362                          }))
363           continue;
364         // Find a read with the same type and indices.
365         HoistableRead matchingRead =
366             findMatchingTransferRead(write, it.value());
367         // Make sure none of the other uses read the part of the tensor modified
368         // by the transfer_write.
369         if (!matchingRead.transferReadOp ||
370             tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
371           continue;
372 
373         LLVM_DEBUG(DBGS() << "Start hoisting\n");
374         hoistReadWrite(matchingRead, write, it.value());
375         changed = true;
376         forOp.erase();
377 
378         // Need to interrupt and restart: erasing the loop messes up the walk.
379         return WalkResult::interrupt();
380       }
381       return WalkResult::advance();
382     });
383     // Apply canonicalization so the newForOp + yield folds immediately, thus
384     // cleaning up the IR and potentially enabling more hoisting.
385     if (changed) {
386       RewritePatternSet patterns(func->getContext());
387       scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
388       (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
389     }
390   }
391 }
392 
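/// Hoist loop-invariant vector.transfer_read / vector.transfer_write pairs on
/// buffers out of their immediately enclosing scf.for. For illustration
/// (shapes and names are illustrative only):
///
///   scf.for %i = %lb to %ub step %s {
///     %r = vector.transfer_read %m[%c0], %pad
///         : memref<128xf32>, vector<4xf32>
///     %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
///     vector.transfer_write %u, %m[%c0] : vector<4xf32>, memref<128xf32>
///   }
///
/// becomes:
///
///   %r = vector.transfer_read %m[%c0], %pad : memref<128xf32>, vector<4xf32>
///   %f = scf.for %i = %lb to %ub step %s iter_args(%v = %r)
///       -> (vector<4xf32>) {
///     %u = "some_use"(%v) : (vector<4xf32>) -> vector<4xf32>
///     scf.yield %u : vector<4xf32>
///   }
///   vector.transfer_write %f, %m[%c0] : vector<4xf32>, memref<128xf32>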
393 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
394   bool changed = true;
395   while (changed) {
396     changed = false;
    // First move loop-invariant ops out of their loops. This needs to be done
    // before the walk below, as we cannot move ops while walking the function
    // without invalidating the walk.
399     func.walk([&](LoopLikeOpInterface loopLike) {
400       if (failed(moveLoopInvariantCode(loopLike)))
401         llvm_unreachable(
402             "Unexpected failure to move invariant code out of loop");
403     });
404 
405     func.walk([&](vector::TransferReadOp transferRead) {
406       if (!transferRead.getShapedType().isa<MemRefType>())
407         return WalkResult::advance();
408 
409       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
410                         << *transferRead.getOperation() << "\n");
411       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
412       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
413                         << "\n");
414       if (!loop)
415         return WalkResult::advance();
416 
417       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
418                         << "\n");
419 
420       SetVector<Operation *> forwardSlice;
421       getForwardSlice(transferRead.getOperation(), &forwardSlice);
422 
423       // Look for the last TransferWriteOp in the forwardSlice of
424       // `transferRead` that operates on the same memref.
425       vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite || candidateWrite.source() != transferRead.source())
          continue;
        // The slice is iterated in reverse topological order, so the first
        // match is the last such write; stop here.
        transferWrite = candidateWrite;
        break;
      }
432 
433       // All operands of the TransferRead must be defined outside of the loop.
434       for (auto operand : transferRead.getOperands())
435         if (!loop.isDefinedOutsideOfLoop(operand))
436           return WalkResult::advance();
437 
438       // Only hoist transfer_read / transfer_write pairs for now.
439       if (!transferWrite)
440         return WalkResult::advance();
441 
442       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
443                         << "\n");
444 
445       // Approximate aliasing by checking that:
446       //   1. indices are the same,
447       //   2. no other operations in the loop access the same memref except
448       //      for transfer_read/transfer_write accessing statically disjoint
449       //      slices.
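      //      For example (illustrative values), a transfer_write at index %c0
      //      and a transfer_read at index %c8, both of vector<4xf32>, cover
      //      the disjoint intervals [0, 4) and [8, 12).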
      if (transferRead.indices() != transferWrite.indices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();
453 
454       // TODO: may want to memoize this information for performance but it
455       // likely gets invalidated often.
456       DominanceInfo dom(loop);
457       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
458         return WalkResult::advance();
459       for (auto &use : transferRead.source().getUses()) {
460         if (!loop->isAncestor(use.getOwner()))
461           continue;
462         if (use.getOwner() == transferRead.getOperation() ||
463             use.getOwner() == transferWrite.getOperation())
464           continue;
465         if (auto transferWriteUse =
466                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
467           if (!vector::isDisjointTransferSet(
468                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
469                   cast<VectorTransferOpInterface>(
470                       transferWriteUse.getOperation())))
471             return WalkResult::advance();
472         } else if (auto transferReadUse =
473                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
474           if (!vector::isDisjointTransferSet(
475                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
476                   cast<VectorTransferOpInterface>(
477                       transferReadUse.getOperation())))
478             return WalkResult::advance();
479         } else {
480           // Unknown use, we cannot prove that it doesn't alias with the
481           // transferRead/transferWrite operations.
482           return WalkResult::advance();
483         }
484       }
485 
486       // Hoist read before.
487       if (failed(loop.moveOutOfLoop({transferRead})))
488         llvm_unreachable(
489             "Unexpected failure to move transfer read out of loop");
490 
491       // Hoist write after.
492       transferWrite->moveAfter(loop);
493 
      // Rewrite `loop` with new yields by cloning it; the original loop is
      // erased below.
495       OpBuilder b(transferRead);
496       auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
497                                          transferWrite.vector());
498 
      // The transfer_write has been hoisted; update the written value to the
      // value yielded by `newForOp`.
      transferWrite.vector().replaceAllUsesWith(newForOp.getResults().back());
503 
504       changed = true;
505       loop.erase();
506       // Need to interrupt and restart because erasing the loop messes up the
507       // walk.
508       return WalkResult::interrupt();
509     });
510   }
511 }
512