//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

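// Illustration of the filter semantics below (hypothetical scenario): with an
// empty `opFilter`, every op is allowed. As soon as there is at least one
// ALLOW entry, only ops matched by some ALLOW entry are allowed. A matching
// DENY entry always wins; e.g., an ALLOW entry for the `tensor` dialect
// combined with a DENY entry for tensor.extract allows all `tensor` ops
// except tensor.extract.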
bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !filterHasAllowRule();
  for (const OpFilterEntry &entry : opFilter) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case OpFilterEntry::ALLOW:
      isAllowed |= filterResult;
      break;
    case OpFilterEntry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

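/// Set the insertion point of `b` to the earliest point at which `value` is
/// available: the beginning of its block if `value` is a block argument, or
/// right after its defining op otherwise.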
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support
  // it. Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
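///
/// For illustration (hypothetical IR), assuming tensor.extract_slice only
/// creates an alias and tensor.extract performs a read of its source:
///   %0 = tensor.extract_slice %t[0] [4] [1] : tensor<8xf32> to tensor<4xf32>
///   %1 = tensor.extract %0[%idx] : tensor<4xf32>
/// Here, isValueRead(%t) returns true: the read via tensor.extract is reached
/// through the alias-only tensor.extract_slice.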
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

/// Starting from `value`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further.
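///
/// For illustration (hypothetical IR), assuming each "test" op's result
/// aliases its tensor operand and `condition` matches only ops named "test.A":
///   %0 = "test.A"() : () -> tensor<4xf32>
///   %1 = "test.B"(%0) : (tensor<4xf32>) -> tensor<4xf32>
///   %2 = "test.C"(%1) : (tensor<4xf32>) -> tensor<4xf32>
/// findValueInReverseUseDefChain(%2, condition) walks %2 -> %1 -> %0 and
/// returns {%0}. Values whose defining op is not bufferizable (or not allowed
/// by the options) terminate the traversal and are returned as well.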
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

/// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

/// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

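/// Return a memref value for the given tensor value. If the tensor is the
/// result of a bufferization.to_tensor op, return that op's memref operand.
/// Otherwise, insert a bufferization.to_memref op right after the point where
/// the tensor value becomes available.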
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
  Value tensor = opOperand.get();
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

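/// Replace `op` with the given bufferized values, one per OpResult. Results of
/// tensor type are wrapped in a bufferization.to_tensor op so that remaining
/// users, which still expect tensors, keep working.
///
/// For illustration (hypothetical IR), replacing the single tensor result of
/// an op with a memref value %m inserts
///   %t = bufferization.to_tensor %m : memref<4xf32>
/// and all former users of the op's result now use %t. The to_tensor op is
/// expected to lose its users and fold away as bufferization proceeds.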
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyBufferizationState");
}

/// Return `true` if the given OpOperand has been decided to bufferize
/// in-place.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and a copy are inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Aliasing tensors within the same block are also
/// taken into account.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}

/// Create a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape, bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation for `shapedValue` at the current insertion point of
/// `b`. Whether a matching deallocation is emitted later is controlled by
/// `dealloc`; if not specified, it is derived from the bufferization options
/// and from whether the tensor is yielded from its block.
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

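/// Rewrite the temporary memref.alloca ops that were inserted by the
/// bufferization into proper allocations and, unless marked with the
/// skip-dealloc attribute, matching deallocations. If `onlyLeakingAllocs` is
/// set, only allocations that are marked to skip deallocation are processed.
/// `changed` (if non-null) is set to true if any IR was modified.
///
/// For illustration (hypothetical IR), an alloca tagged by the bufferization
///   %0 = memref.alloca() {bufferization.allocation} : memref<4xf32>
/// is rewritten to
///   %0 = memref.alloc() : memref<4xf32>
///   ...
///   memref.dealloc %0 : memref<4xf32>
/// with the dealloc placed right before the terminator of the alloca's block,
/// assuming the default alloc/dealloc functions and no skip-dealloc marker.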
LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs, bool *changed) {
  IRRewriter rewriter(op->getContext());
  if (changed)
    *changed = false;

  // Bufferization creates memref.alloca ops. After bufferization, these must
  // be rewritten to alloc/dealloc ops as specified in the bufferization
  // options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);
    if (changed)
      *changed = true;

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific MemRef type and function argument helpers.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

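/// Return a MemRef type corresponding to `tensorType`, taking the layout map
/// preferences of `options` into account.
///
/// For illustration (hypothetical types): with fully dynamic layout maps
/// enabled and no explicit `layout`, a tensor<4x8xf32> maps to a memref type
/// with dynamic strides and offset, e.g.,
///   memref<4x8xf32,
///       affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
/// whereas with `fullyDynamicLayoutMaps` disabled it maps to a plain
/// memref<4x8xf32>.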
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most
  // dynamic one.
  // TODO: address space decisions to connect with the actual alloc.
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}