//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;

namespace {

// TODO: Ops in the linalg dialect can directly implement this interface.

/// Generic conversion for any LinalgOp on tensors.
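///
/// Illustrative sketch (indexing maps and payload elided): an op such as
///
///   %r = linalg.generic ... ins(%A : tensor<?xf32>)
///            outs(%B : tensor<?xf32>) {...} -> tensor<?xf32>
///
/// is rewritten to the same op on the buffers of its operands, which has no
/// tensor results; uses of %r are then replaced with the output buffer:
///
///   linalg.generic ... ins(%a : memref<?xf32>) outs(%b : memref<?xf32>) {...}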
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
                                       const BufferizationState &state) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasTensorSemantics())
    return op->emitError() << "op does not have tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumInputs());
  for (OpOperand *opOperand : op.getInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    // Input operands are never written to.
    newInputBuffers.push_back(
        *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    SmallVector<OpOperand *> aliasingOpOperands =
        state.getAliasingOpOperand(opResult);
    assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
    FailureOr<Value> resultBuffer =
        state.getBuffer(rewriter, *aliasingOpOperands.front());
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);
  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  auto newOp = cast<LinalgOp>(op.cloneWithoutRegions(
      rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
                              newOp->getRegion(0).begin());

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

  return success();
}

/// Linalg OpResults usually bufferize inplace with their tied output
/// OpOperands. However, if an output OpOperand is not used in the computation,
/// it is better to bufferize inplace with an actually used input OpOperand;
/// less memory will be touched that way.
///
/// Example:
/// O(i, j) = A(i, j) + B(j)  --> bufferizes inplace to:  A(i, j) += B(j)
///
/// O(i, j) = A(j, i) + B(j)  --> cannot bufferize inplace with A because
///                               indexing maps are not identical
///
/// O(i, j) += A(i, j) + B(j) --> Output is used in computation.
/// This could bufferize inplace with A:
/// A(i, j) += O(i, j) + B(j)
/// However, we choose to bufferize inplace with O here, as there is no clear
/// benefit of choosing A. TODO: We may want to consider both options and make
/// an informed decision during analysis in the future.
static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
  DenseMap<OpOperand *, OpResult> mapping;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *tiedOperand =
        op.getOutputTensorOperands()[opResult.getResultNumber()];
    AffineMap outputIndexingMap = op.getTiedIndexingMap(tiedOperand);
    bool onlyParallelIterators = op.getNumParallelLoops() == op.getNumLoops();
    bool tiedOperandUsed = op.payloadUsesValueFromOperand(tiedOperand);

    // If the output arg is used in the computation or at least one iterator is
    // not parallel, try to bufferize inplace with the corresponding output
    // tensor.
    if (tiedOperandUsed || !onlyParallelIterators) {
      mapping[tiedOperand] = opResult;
      continue;
    }

    // Otherwise, try to bufferize inplace with one of the inputs.
    OpOperand *chosenOperand = nullptr;
    for (OpOperand *opOperand : op.getInputTensorOperands()) {
      if (opOperand->get().getType() != opResult.getType())
        continue;
      if (!op.payloadUsesValueFromOperand(opOperand))
        continue;
      if (op.getTiedIndexingMap(opOperand) != outputIndexingMap)
        continue;
      // Skip if another OpResult already aliases with this OpOperand.
      if (mapping.count(opOperand))
        continue;
      assert(op.getTiedIndexingMap(opOperand).isProjectedPermutation() &&
             "expected projected permutation");
      chosenOperand = opOperand;
      break;
    }

    // No suitable input tensor found. Use output tensor.
    // TODO: This operand could bufferize inplace with OpOperands that have the
    // correct type, even if they are not used inside the computation.
    if (!chosenOperand)
      chosenOperand = tiedOperand;

    mapping[chosenOperand] = opResult;
  }
  return mapping;
}

/// Bufferization of linalg structured ops (e.g., linalg.generic). Replace with
/// a new op of the same kind that operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
    : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
                                                    OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    // Operand is read if it is used in the computation.
    auto genericOp = cast<linalg::LinalgOp>(op);
    return genericOp.payloadUsesValueFromOperand(&opOperand);
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    // Operand is written to if it has an aliasing OpResult.
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th OpResult may alias with the i-th "out" tensor.
    if (state.getOptions().alwaysAliasingWithDest)
      return {genericOp.getOutputOperand(opResult.getResultNumber())};

    // We can try to be smart and alias in-place with an "in" tensor if the
    // corresponding "out" tensor is not used in the computation.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
      if (pairs[opOperand] == opResult)
        return {opOperand};
    return {};
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto genericOp = cast<linalg::LinalgOp>(op);

    // By default, the i-th "out" tensor may alias with the i-th OpResult.
    if (state.getOptions().alwaysAliasingWithDest) {
      if (genericOp.isOutputTensor(&opOperand))
        return {genericOp.getTiedOpResult(&opOperand)};
      return {};
    }

    // We can try to be smart. See comment in `getAliasingOpOperand`.
    // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
    DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
    if (!pairs.count(&opOperand))
      return {};
    return {pairs[&opOperand]};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
  }
};

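/// Bufferization of linalg.init_tensor. Replace with a new buffer allocation
/// (via `createAlloc`). No copy is needed: an init_tensor only allocates and
/// carries no defined contents.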
struct InitTensorOpInterface
    : public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
                                                    linalg::InitTensorOp> {
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const BufferizationState &state) const {
    // InitTensorOps allocate but do not write.
    return false;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto initTensorOp = cast<linalg::InitTensorOp>(op);

    // The InitTensorOp may have been eliminated.
    if (initTensorOp->getUses().empty())
      return success();

    FailureOr<Value> alloc =
        createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
                    state.getOptions().createDeallocs, state.getOptions());
    if (failed(alloc))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, *alloc);
    return success();
  }
};

/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
/// that operates entirely on memrefs.
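///
/// Loop bounds and step are kept as-is; tensor input/output operands are
/// replaced with their buffers, and the results of the old op are replaced
/// with the corresponding output buffers. Tensors yielded by the old
/// terminator are copied into the output buffers; if everything bufferized
/// inplace, these copies fold away.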
struct TiledLoopOpInterface
    : public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
                                                    linalg::TiledLoopOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // linalg.tiled_loop operands alone do not bufferize to a memory read, but
    // one of the uses of their matching bbArgs may.
    return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);

    // Only operands with an aliasing OpResult (i.e., output operands) bufferize
    // to a memory write.
    return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Output operands are tied to their corresponding OpResults.
    OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
    if (!opResult)
      return {};
    return {opResult};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
    // inplace from the perspective of nested ops:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);

    // Compute new inputs, outputs and results.
    SmallVector<Value> newInputs, newOutputs, newResults;
    for (unsigned i = tiledLoopOp.getNumControlOperands();
         i < tiledLoopOp->getNumOperands(); ++i) {
      OpOperand &operand = tiledLoopOp->getOpOperand(i);
      Value rewrittenValue = operand.get();
      if (rewrittenValue.getType().isa<TensorType>()) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, operand);
        if (failed(bufferOrFailure))
          return failure();
        rewrittenValue = *bufferOrFailure;
      }
      if (i <
          tiledLoopOp.getNumControlOperands() + tiledLoopOp.getNumInputs()) {
        newInputs.push_back(rewrittenValue);
      } else {
        newOutputs.push_back(rewrittenValue);
        if (operand.get().getType().isa<TensorType>())
          newResults.push_back(rewrittenValue);
      }
    }

    // Create new TiledLoopOp.
    auto newTiledLoopOp = rewriter.create<TiledLoopOp>(
        tiledLoopOp.getLoc(), tiledLoopOp.lowerBound(),
        tiledLoopOp.upperBound(), tiledLoopOp.step(), newInputs, newOutputs,
        tiledLoopOp.iterator_types(), tiledLoopOp.distribution_types());

    // Remove the terminator, if the builder created one in the new op's body.
    // The old op's terminator is handled further below.
    if (!newTiledLoopOp.getBody()->empty())
      rewriter.eraseOp(newTiledLoopOp.getBody()->getTerminator());

    // Compute new loop body arguments.
    SmallVector<Value> newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs;
    ValueRange newInductionVars = newTiledLoopOp.getInductionVars();
    newBlockArgs.append(newInductionVars.begin(), newInductionVars.end());

    ValueRange newRegionInArgs = newTiledLoopOp.getRegionInputArgs();
    ValueRange newRegionOutArgs = newTiledLoopOp.getRegionOutputArgs();
    newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end());
    newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end());

    ValueRange oldRegionInArgs = tiledLoopOp.getRegionInputArgs();
    ValueRange oldRegionOutArgs = tiledLoopOp.getRegionOutputArgs();
    oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end());
    oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end());
    assert(newRegionInArgs.size() == oldRegionInArgs.size() &&
           "expected same number of input args");
    assert(newRegionOutArgs.size() == oldRegionOutArgs.size() &&
           "expected same number of output args");

    for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) {
      Value oldArg = std::get<0>(it);
      Value newArg = std::get<1>(it);
      rewriter.setInsertionPointToStart(newTiledLoopOp.getBody());
      if (oldArg.getType().isa<TensorType>()) {
        newBlockArgs.push_back(rewriter.create<bufferization::ToTensorOp>(
            oldArg.getLoc(), newArg));
      } else {
        newBlockArgs.push_back(newArg);
      }
    }

    // Move old body into new loop.
    rewriter.mergeBlocks(tiledLoopOp.getBody(), newTiledLoopOp.getBody(),
                         newBlockArgs);

    // Replace previous terminator with a new one that does not yield anything.
    auto oldTerminator =
        cast<linalg::YieldOp>(newTiledLoopOp.getBody()->getTerminator());
    rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody());
    auto newTerminator =
        rewriter.create<linalg::YieldOp>(oldTerminator->getLoc());

    // Copy buffer of yielded tensor to output buffer. If everything bufferized
    // inplace, this copy will fold away.
    rewriter.setInsertionPoint(newTerminator);
    for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) {
      Value output = std::get<1>(it);
      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
          newTerminator.getLoc(), output.getType(), std::get<0>(it));
      if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
                              output, state.getOptions())))
        return failure();
    }

    // Erase old terminator.
    rewriter.eraseOp(oldTerminator);

    // Replace results and delete old op.
    replaceOpWithBufferizedValues(rewriter, op, newResults);

    return success();
  }
};

/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
/// bufferization.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    linalg::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  SmallVector<OpResult>
  getAliasingOpResult(Operation *op, OpOperand &opOperand,
                      const BufferizationState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const BufferizationState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    auto yieldOp = cast<linalg::YieldOp>(op);

    if (!yieldOp->getParentOfType<TiledLoopOp>())
      return yieldOp->emitError(
          "expected that linalg.yield terminates a tiled_loop");

    assert(yieldOp->getOpOperands().empty() &&
           "expected that linalg.yield was bufferized together with"
           " tiled_loop");
    return success();
  }
};

/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `BufferizableOpInterface` with each of them.
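/// For example, `LinalgOpInterfaceHelper<GenericOp, FillOp>` registers
/// `LinalgOpInterface<GenericOp>` and `LinalgOpInterface<FillOp>`, peeling off
/// one op type per recursion step.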
template <typename... OpTys>
struct LinalgOpInterfaceHelper;

template <typename First, typename... Others>
struct LinalgOpInterfaceHelper<First, Others...> {
  static void registerOpInterface(DialectRegistry &registry) {
    registry.addOpInterface<First, LinalgOpInterface<First>>();
    LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
  }
};

template <>
struct LinalgOpInterfaceHelper<> {
  static void registerOpInterface(DialectRegistry &registry) {}
};

} // namespace

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}

/// Return true if the given `insertionPoint` dominates all uses of
/// `initTensorOp`.
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                        Operation *insertionPoint,
                                        Operation *initTensorOp) {
  for (Operation *user : initTensorOp->getUsers())
    if (!domInfo.dominates(insertionPoint, user))
      return false;
  return true;
}

/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
/// that the replacement may use any value from `neededValues`.
static Operation *
findValidInsertionPoint(Operation *initTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Gather all possible insertion points: the location of `initTensorOp` and
  // right after the definition of each value in `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(initTensorOp);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the block
    //                                (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //                           defining op (the anchor op or one of its
    //                           parents).
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before all uses.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
/// with the result of `rewriteFunc` if it is anchored on a matching
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
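///
/// Example (sketch): given
///   %0 = linalg.init_tensor ...
///   %1 = linalg.fill(%cst, %0)
///   %2 = tensor.insert_slice %1 into %t[..][..][..]
/// the insert_slice source operand (%1) is anchored on the InitTensorOp:
/// following aliasing OpOperands backwards (%1 -> %0) ends at %0.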
LogicalResult mlir::linalg::eliminateInitTensors(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
    SmallVector<Operation *> &newOps) {
  OpBuilder b(op->getContext());

  WalkResult status = op->walk([&](Operation *op) {
    for (OpOperand &operand : op->getOpOperands()) {
      // Skip operands that do not bufferize inplace.
      if (!aliasInfo.isInPlace(operand))
        continue;
      // All values that are needed to create the replacement op.
      SmallVector<Value> neededValues;
      // Is this a matching OpOperand?
      if (!anchorMatchFunc(operand, neededValues))
        continue;
      SetVector<Value> maybeInitTensor =
          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
            // Continue traversal until this function returns true.
            OpResult opResult = val.dyn_cast<OpResult>();
            if (!opResult)
              return true;
            SmallVector<OpOperand *> opOperands =
                state.getAliasingOpOperand(opResult);
            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
                  return aliasInfo.isInPlace(*operand);
                }))
              return true;
            // Only equivalent tensors are supported at the moment.
            // TODO: Support cases such as extract_slice(init_tensor)
            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
                                                             opResult);
            });
          });

      // Replace only if the reverse use-def chain ends at exactly one
      // InitTensorOp.
      if (maybeInitTensor.size() != 1 ||
          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
        return WalkResult::skip();
      Value initTensor = maybeInitTensor.front();

      // Find a suitable insertion point.
      Operation *insertionPoint =
          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
      if (!insertionPoint)
        continue;

      // Create a replacement for the InitTensorOp.
      b.setInsertionPoint(insertionPoint);
      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
      if (!replacement)
        continue;

      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
      // InitTensorOps without uses are ignored by the bufferization.
      initTensor.replaceAllUsesWith(replacement);
      aliasInfo.createAliasInfoEntry(replacement);
      aliasInfo.unionAliasSets(initTensor, replacement);
      aliasInfo.unionEquivalenceClasses(initTensor, replacement);

      // Register replacement ops.
      if (Operation *newOp = replacement.getDefiningOp())
        newOps.push_back(newOp);
    }

    // Advance to the next operation.
    return WalkResult::advance();
  });

  return failure(status.wasInterrupted());
}

/// Try to eliminate InitTensorOps inside `op`. An InitTensorOp can be
/// eliminated if it is eventually inserted into another tensor (and some other
/// conditions are met).
///
/// E.g.:
/// %0 = linalg.init_tensor
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
///
/// InitTensorOp elimination will try to fill %t inplace instead of filling a
/// new allocation %0 and inserting it into %t. This is done by replacing the
/// InitTensorOp with:
///
/// %0 = tensor.extract_slice %t[10][20][1]
///
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
/// those bufferize inplace in the absence of other conflicts.
///
/// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
/// source's reverse use-def chain is eliminated if:
/// * The InsertSliceOp was decided to bufferize inplace.
/// * On the reverse use-def chain path from the InsertSliceOp to the
///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
///   relation is "equivalent" (TODO: can be relaxed if needed).
/// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
///
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
    SmallVector<Operation *> &newOps) {
  return eliminateInitTensors(
      op, state, aliasInfo,
      /*anchorMatchFunc=*/
      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
        auto insertSliceOp =
            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
        if (!insertSliceOp)
          return false;
        // Only inplace bufferized InsertSliceOps are eligible.
        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
          return false;
        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
          return false;

        // Collect all values that are needed to construct the replacement op.
        neededValues.append(insertSliceOp.offsets().begin(),
                            insertSliceOp.offsets().end());
        neededValues.append(insertSliceOp.sizes().begin(),
                            insertSliceOp.sizes().end());
        neededValues.append(insertSliceOp.strides().begin(),
                            insertSliceOp.strides().end());
        neededValues.push_back(insertSliceOp.dest());

        return true;
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) {
        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
        // Expand offsets, sizes and strides to the full rank to handle the
        // rank-reducing case.
        SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
        SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
        SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
        OffsetSizeAndStrideOpInterface::expandToRank(
            insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides,
            [&](Value target, int64_t dim) -> OpFoldResult {
              auto shapedType = target.getType().cast<ShapedType>();
              if (shapedType.isDynamicDim(dim))
                return b.create<tensor::DimOp>(loc, target, dim).result();
              return b.getIndexAttr(shapedType.getDimSize(dim));
            });
        auto t = tensor::ExtractSliceOp::inferRankReducedResultType(
            insertOp.getSourceType().getRank(),
            insertOp.dest().getType().cast<RankedTensorType>(), mixedOffsets,
            mixedSizes, mixedStrides);
        auto extractOp = b.create<tensor::ExtractSliceOp>(
            loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
        return extractOp.result();
      },
      newOps);
}

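// Typical usage (sketch): call this on a DialectRegistry before creating the
// MLIRContext, e.g.:
//   DialectRegistry registry;
//   registry.insert<linalg::LinalgDialect>();
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);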
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
  registry.addOpInterface<linalg::TiledLoopOp, TiledLoopOpInterface>();
  registry.addOpInterface<linalg::YieldOp, YieldOpInterface>();

  // Register all Linalg structured ops. `LinalgOp` is an interface and it is
  // not possible to attach an external interface to an existing interface.
  // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
  LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
      >::registerOpInterface(registry);
}