//===- BufferizableOpInterface.cpp - Bufferizable Ops  --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

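/// Return true if the given op is allowed to be bufferized. Ops from the
/// `func` dialect are rejected unless function boundary bufferization is
/// enabled; all other ops are allowed or denied by the op filter, with DENY
/// rules taking precedence over ALLOW rules.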
bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Special case: If function boundary bufferization is deactivated, do not
  // allow ops that belong to the `func` dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  // All other ops: Allow/disallow according to filter.
  bool isAllowed = !filterHasAllowRule();
  for (const OpFilterEntry &entry : opFilter) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case OpFilterEntry::ALLOW:
      isAllowed |= filterResult;
      break;
    case OpFilterEntry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed. (Even if other ALLOW
        // filters may match.)
        return false;
    }
  }
  return isAllowed;
}

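/// Try to cast the given op to BufferizableOpInterface. Return a "null"
/// interface if the op does not implement the interface or is not allowed to
/// be bufferized according to these options.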
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  if (!isOpAllowed(op))
    return nullptr;
  return bufferizableOp;
}

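/// Try to cast the defining op of the given value to BufferizableOpInterface.
/// Return a "null" interface if there is no defining op, the op does not
/// implement the interface, or the op is not allowed to be bufferized.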
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

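/// Register an initializer that attaches dialect-specific analysis state
/// (created by `fn`) under the given dialect `name` to every newly constructed
/// AnalysisState.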
void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

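/// Set the builder's insertion point right after the definition of `value`: at
/// the beginning of the owning block for a block argument, or directly after
/// the defining op otherwise.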
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

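/// Return the memref buffer that materializes the given tensor value: reuse
/// the memref operand if the tensor is produced by a to_tensor op, otherwise
/// insert a to_memref op right after the tensor's definition.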
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              Optional<ForceInPlacability> overrideInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  // Can `operandBuffer` be used directly or do we need a copy?
  bool inplace =
      overrideInPlace != FORCE_OUT_OF_PLACE &&
      (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
  if (inplace)
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy the buffer if its contents are undefined.
  if (analysisState.hasUndefinedContents(&opOperand))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
    return failure();

  return resultBuffer;
}

/// Return the buffer type for a given OpOperand (tensor) after bufferization.
BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
  Value tensor = opOperand.get();
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();

  return getMemRefType(tensorType, getOptions());
}

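/// Replace an op with the given bufferized values. Tensor OpResults are
/// replaced through an intermediate to_tensor op so that the remaining tensor
/// users stay valid until they are bufferized themselves.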
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyAnalysisState");
}

/// Return `true` if the given OpOperand has been decided to bufferize
/// in-place.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return `true` if the given tensor has undefined contents.
bool AlwaysCopyAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  // There is no analysis, so the conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor in the same block)
/// is yielded from the containing block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (deallocationFn)
    return (*deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

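/// Return a contiguous MemRefType (i.e., with identity layout) with the same
/// shape and element type as the given shaped type.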
static MemRefType getContiguousMemRefType(ShapedType shapedType,
                                          Attribute memorySpace = {}) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions in the returned `memref` type.
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

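/// Create a memref.alloca and mark it as a bufferization-created allocation.
/// If `skipDealloc` is set, additionally mark the op so that no dealloc is
/// created for it later.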
static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape, bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
}

/// Create a memory copy between two memref buffers.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

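/// Rewrite all bufferization-created memref.alloca ops into alloc ops (and,
/// unless the alloca is marked "skip dealloc", matching dealloc ops) as
/// configured in `options`. If `onlyLeakingAllocs` is set, only allocations
/// without a subsequent deallocation are processed.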
LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs, bool *changed) {
  IRRewriter rewriter(op->getContext());
  if (changed)
    *changed = false;

  // Bufferization creates memref.alloca ops. After bufferization, these must be
  // rewritten to alloc/dealloc ops as specified in the bufferization options.
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                            allocaOp.dynamicSizes());
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);
    if (changed)
      *changed = true;

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so the walk above always stops
    // at some op; `op` should never be null here.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helper functions.
//===----------------------------------------------------------------------===//

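/// Return true if the given value is an argument of a block that belongs to a
/// func::FuncOp.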
bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}

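/// Return a MemRefType to which the given TensorType can be bufferized. If a
/// layout attribute is specified, it is used (ranked tensors only). Otherwise,
/// the layout is chosen according to `options.unknownTypeConversion`.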
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Configured with "fully dynamic layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);

  // Case 4: Configured with "static identity layout maps".
  if (options.unknownTypeConversion ==
      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);

  llvm_unreachable("InferLayoutMap is an invalid option");
}

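/// Return a MemRef type with a fully dynamic layout map (dynamic strides and
/// offset). If the given tensor type is unranked, return an unranked MemRef
/// type.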
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}