//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace mlir;
using namespace bufferization;

/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kBufferLayoutAttrName;

/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
    bufferization::BufferizableOpInterface::kInplaceableAttrName;

/// Attribute name used to mark allocs that are created by the bufferization.
static const char *kBufferAllocationAttr = "bufferization.allocation";

/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";

//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//

// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (isOpAllowed(op))
    return dyn_cast<BufferizableOpInterface>(op);
  return nullptr;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
    if (isOpAllowed(bufferizableOp.getOperation()))
      return bufferizableOp;
  return nullptr;
}

void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  stateInitializers.push_back(
      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}

//===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===//

static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
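/// Example (illustrative sketch; the exact aliasing relation is defined by
/// each op's interface implementation): for
/// `%r = tensor.insert_slice %src into %dest`, the OpOperand aliasing with
/// `%r` would typically be the `%dest` operand.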
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  if (Operation *op = result.getDefiningOp())
    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
      return bufferizableOp.getAliasingOpOperand(result, *this);
  return {};
}

/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.getAliasingOpResult(opOperand, *this);
  return {};
}

/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return true.
  return true;
}

/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

  // Unknown op that returns a tensor. The inplace analysis does not support it.
  // Conservatively return false.
  return false;
}

/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
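/// Example (illustrative sketch; the exact semantics depend on each op's
/// interface implementation):
///   %0 = tensor.extract_slice %t[..] ...  // creates an alias, no read
///   %1 = tensor.extract %0[..]            // reads the aliased value
/// Here, `isValueRead(%t)` would return true because a read is reachable
/// through the alias created by the extract_slice.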
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    // Skip over all ops that neither read nor write (but create an alias).
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
        for (OpOperand &use : opResult.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
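//
// Example (illustrative sketch; `condition` is any predicate supplied by the
// caller):
//   %0 = linalg.init_tensor ...
//   %1 = linalg.fill(%cst, %0)
//   %2 = tensor.extract_slice %1 ...
// Starting from %2, the traversal follows the aliasing OpOperands
// (%2 -> %1 -> %0) and stops at the first value matching `condition` (or at a
// block argument / the end of the chain).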
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(value);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();
    if (condition(value) || value.isa<BlockArgument>()) {
      result.insert(value);
      continue;
    }

    OpResult opResult = value.cast<OpResult>();
    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
      result.insert(value);
      continue;
    }

    for (OpOperand *o : opOperands)
      workingSet.insert(o->get());
  }

  return result;
}

// Find the Values of the last preceding write of a given Value.
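// Example (illustrative): for the result of a `linalg.fill`, the fill result
// itself is returned because it is a memory write. For the result of a
// `linalg.init_tensor`, the init_tensor result is returned even though it is
// not a memory write, because the end of the reverse use-def chain was
// reached.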
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  return findValueInReverseUseDefChain(value, [&](Value value) {
    Operation *op = value.getDefiningOp();
    if (!op)
      return true;
    auto bufferizableOp = options.dynCastBufferizableOp(op);
    if (!bufferizableOp)
      return true;
    return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
  });
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const BufferizationOptions::AnalysisStateInitFn &fn :
       options.stateInitializers)
    fn(*this);
}

// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
  assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_memref would be invalid: mismatching ranks");
#endif
}

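/// Illustrative behavior (a sketch of the two cases handled below): for
/// `%t = bufferization.to_tensor %m`, the buffer of `%t` is simply `%m`.
/// For any other tensor value, a conversion op is materialized, e.g.:
///   %buffer = bufferization.to_memref %t : memref<?xf32>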
Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
                                        const BufferizationOptions &options) {
  auto tensorType = tensor.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, tensor);
  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(tensor, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                    tensor);
}

/// Return the buffer (memref) to be used for the given OpOperand (tensor).
/// Allocate a new buffer and copy over data from the existing buffer if
/// out-of-place bufferization is necessary.
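///
/// Out-of-place sketch (illustrative; the exact ops depend on the
/// BufferizationOptions): a new buffer is allocated next to the operand
/// buffer and, unless one of the checks below proves the copy unnecessary,
/// the contents are copied right before the bufferized op:
///   %alloc = memref.alloc(...) : memref<...>
///   memref.copy %operand_buffer, %alloc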
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
                              bool forceInPlace,
                              Optional<Operation *> customCopyInsertionPoint) {
  const BufferizationOptions &options = analysisState.getOptions();
  OpBuilder::InsertionGuard guard(rewriter);
  Operation *op = opOperand.getOwner();
  Location loc = op->getLoc();
  SmallVector<OpResult> aliasingOpResults =
      analysisState.getAliasingOpResult(opOperand);
  Value operand = opOperand.get();
  Value operandBuffer = lookupBuffer(rewriter, operand, options);

  if (forceInPlace || analysisState.isInPlace(opOperand))
    return operandBuffer;

  // Bufferizing out-of-place: Allocate a new buffer.
  // Move insertion point right after `operandBuffer`. That is where the
  // allocation should be inserted (in the absence of allocation hoisting).
  setInsertionPointAfter(rewriter, operandBuffer);
  // Allocate the result buffer. The buffer should be deallocated if the tensor
  // is not yielded and deallocs are enabled in general.
  bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
    return getAnalysisState().isTensorYielded(v);
  });
  FailureOr<Value> resultBuffer = createAlloc(
      rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
  if (failed(resultBuffer))
    return failure();
  // Do not copy if the last preceding writes of `operand` are ops that do
  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
  // use-def chain, it returns that value, regardless of whether it is a
  // memory write or not.
  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
                                              analysisState);
        return true;
      }))
    return resultBuffer;
  // Do not copy if the copied data is never read.
  if (!aliasingOpResults.empty() &&
      !analysisState.bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
        return analysisState.isValueRead(opResult);
      }))
    return resultBuffer;
  // Do not copy if this op does not read the data, but writes it.
  if (analysisState.bufferizesToMemoryWrite(opOperand) &&
      !analysisState.bufferizesToMemoryRead(opOperand))
    return resultBuffer;

  if (customCopyInsertionPoint) {
    rewriter.setInsertionPoint(*customCopyInsertionPoint);
  } else {
    // The copy happens right before the op that is bufferized.
    rewriter.setInsertionPoint(op);
  }
  if (failed(
          createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
    return failure();

  return resultBuffer;
}

void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard g(rewriter);

  // Replace all OpResults with the given values.
  SmallVector<Value> replacements;
  for (OpResult opResult : op->getOpResults()) {
    Value replacement = values[opResult.getResultNumber()];
    if (opResult.getType().isa<TensorType>()) {
      // The OpResult is a tensor. Such values are replaced with memrefs during
      // bufferization.
      assert((replacement.getType().isa<MemRefType>() ||
              replacement.getType().isa<UnrankedMemRefType>()) &&
             "tensor op result should be replaced with a memref value");
      // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
      rewriter.setInsertionPointAfter(op);
      replacement = rewriter.create<bufferization::ToTensorOp>(
          replacement.getLoc(), replacement);
    }
    replacements.push_back(replacement);
  }

  rewriter.replaceOp(op, replacements);
}

AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
    const BufferizationOptions &options)
    : AnalysisState(options) {
  // Note: Allocations must be deallocated with a subsequent run of the buffer
  // deallocation pass.
  assert(!options.createDeallocs &&
         "cannot create deallocs with AlwaysCopyBufferizationState");
}

/// Return `true` if the given OpOperand has been decided to bufferize in-place.
bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
  // OpOperands that bufferize to a memory write are out-of-place, i.e., an
  // alloc and copy is inserted.
  return !bufferizesToMemoryWrite(opOperand);
}

/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                            Value v2) const {
  // There is no analysis, so we do not know if the values are equivalent. The
  // conservative answer is "false".
  return false;
}

/// Return true if the given tensor (or an aliasing tensor) is yielded from
/// the containing block. Also include all aliasing tensors in the same block.
bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
  // There is no analysis, so conservatively answer "true".
  return true;
}

//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//

/// Create a memref allocation with the given type and dynamic extents.
static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape,
                                    const BufferizationOptions &options) {
  if (options.allocationFn)
    return (*options.allocationFn)(b, loc, type, dynShape,
                                   options.bufferAlignment);

  // Default buffer allocation via AllocOp.
  Value allocated = b.create<memref::AllocOp>(
      loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
  return allocated;
}

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult
bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
                             const BufferizationOptions &options) {
  if (options.deallocationFn)
    return (*options.deallocationFn)(b, loc, allocatedBuffer);

  // Default buffer deallocation via DeallocOp.
  b.create<memref::DeallocOp>(loc, allocatedBuffer);
  return success();
}

/// Compute the type of the `memref` buffer to allocate for `shapedValue`.
/// Also return (by reference in `dynShape`) the values for the dynamic
/// dimensions in the returned `memref` type.
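///
/// Example (illustrative): for a `shapedValue` of type `tensor<4x?xf32>`, the
/// allocation type is `memref<4x?xf32>` and `dynShape` receives one Value for
/// the dynamic dimension, obtained either from reified result shapes or from
/// a `memref.dim` op.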
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                            Value shapedValue,
                                            SmallVectorImpl<Value> &dynShape) {
  MemRefType allocMemRefType =
      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());

  // Compute the dynamic part of the shape.
  bool reifiedShapes = false;
  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
          shapedValue.getDefiningOp())) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
      reifiedShapes = true;
      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
      auto &shape = resultDims[resultValue.getResultNumber()];
      for (const auto &dim : enumerate(allocMemRefType.getShape()))
        if (ShapedType::isDynamic(dim.value()))
          dynShape.push_back(shape[dim.index()]);
    }
  }

  if (!reifiedShapes) {
    for (const auto &dim : enumerate(allocMemRefType.getShape()))
      if (ShapedType::isDynamic(dim.value())) {
        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
                shapedValue.getType().isa<MemRefType>()) &&
               "expected MemRef type");
        dynShape.push_back(
            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
      }
  }

  return allocMemRefType;
}

static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
                                    ValueRange dynShape, bool skipDealloc) {
  auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
  allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
  if (skipDealloc)
    allocaOp->setAttr(kSkipDeallocAttr, b.getUnitAttr());
  return allocaOp.getResult();
}

/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
/// block in case of a bbArg).
FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
                                                 Value shapedValue,
                                                 Optional<bool> dealloc) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(b);

  // Compute allocation memref type.
  assert(shapedValue.getType().isa<ShapedType>());
  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
  SmallVector<Value> dynShape;
  MemRefType allocMemRefType =
      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);

  // Should the buffer be deallocated again or should we let it leak?
  bool skipDealloc;
  if (dealloc) {
    skipDealloc = !dealloc.getValue();
  } else {
    assert(shapedValue.getType().isa<TensorType>() &&
           "must specify `dealloc` if non-tensor value is passed");
    // Buffer should not be deallocated if deallocs are generally deactivated
    // or if the tensor is yielded from a block.
    skipDealloc = !getOptions().createDeallocs ||
                  getAnalysisState().isTensorYielded(shapedValue);
  }

  // Create the buffer allocation.
  Value alloc =
      createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);

  // Insert a cast if a different type was requested.
  if (memRefType && memRefType != allocMemRefType) {
    assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
           "createAlloc: cast incompatible");
    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
  }

  return alloc;
}

/// Create a memory copy between two memref buffers.
LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
                                          Value from, Value to,
                                          const BufferizationOptions &options) {
  if (options.memCpyFn)
    return (*options.memCpyFn)(b, loc, from, to);

  b.create<memref::CopyOp>(loc, from, to);
  return success();
}

LogicalResult
bufferization::createAllocDeallocOps(Operation *op,
                                     const BufferizationOptions &options,
                                     bool onlyLeakingAllocs) {
  IRRewriter rewriter(op->getContext());

  // Bufferization creates memref.alloca ops. After bufferization, these must be
  // rewritten to alloc/dealloc ops as specified in the bufferization options.
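  // For example (illustrative; the alignment and dealloc placement depend on
  // the options and on the surrounding IR):
  //   %0 = memref.alloca() {bufferization.allocation} : memref<4xf32>
  // is rewritten to
  //   %0 = memref.alloc() {alignment = 128 : i64} : memref<4xf32>
  //   ...
  //   memref.dealloc %0 : memref<4xf32>   // inserted before the terminator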
  WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
    // Ignore memref.alloca ops that were not created by the bufferization.
    if (!allocaOp->hasAttr(kBufferAllocationAttr))
      return WalkResult::skip();
    // If `onlyLeakingAllocs`, process only ops that are marked as
    // "skip dealloc".
    bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
    if (onlyLeakingAllocs && !skipDealloc)
      return WalkResult::skip();

    // Create alloc.
    Block *block = allocaOp->getBlock();
    rewriter.setInsertionPoint(allocaOp);
    FailureOr<Value> alloc =
        createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
                    allocaOp.dynamicSizes(), options);
    if (failed(alloc))
      return WalkResult::interrupt();
    rewriter.replaceOp(allocaOp, *alloc);

    // Stop here if the buffer should not be deallocated.
    if (skipDealloc)
      return WalkResult::advance();

    // Create dealloc.
    rewriter.setInsertionPoint(block->getTerminator());
    if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  return success(!status.wasInterrupted());
}

/// Try to hoist all new buffer allocations until the next hoisting barrier.
// TODO: Consolidate this function with the existing buffer hoisting pass.
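//
// Example (illustrative; assumes the loop op is not an allocation hoisting
// barrier and FuncOp is):
//   scf.for ... {
//     %0 = memref.alloca() {bufferization.allocation} : memref<4xf32>
//     ...
//   }
// The static-shaped alloca is moved to the beginning of the barrier op's
// block that contains the loop, i.e., the function's entry block.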
LogicalResult
bufferization::hoistBufferAllocations(Operation *op,
                                      const BufferizationOptions &options) {
  // Nothing to do if allocation hoisting is deactivated.
  if (!options.hoistAllocations)
    return success();

  // Gather all buffer allocations that were created by the bufferization.
  SmallVector<Operation *> allocaOps;
  op->walk([&](memref::AllocaOp allocaOp) {
    if (allocaOp->hasAttr(kBufferAllocationAttr))
      allocaOps.push_back(allocaOp);
  });

  for (Operation *allocaOp : allocaOps) {
    // TODO: Hoisting of allocs with dynamic shape not implemented.
    if (!allocaOp->getOpOperands().empty())
      continue;

    Operation *op = allocaOp->getParentOp();
    while (op) {
      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
        if (bufferizableOp.isAllocationHoistingBarrier()) {
          break;
        }
      } else {
        // Op is not bufferizable: It may not be safe to hoist across this op.
        break;
      }
      op = op->getParentOp();
    }

    // FuncOp is an allocation hoisting barrier, so this should never happen.
    assert(op && "allocation hoisting barrier not found");

    // Nothing to do if the insertion point is in the same block.
    if (op == allocaOp->getParentOp())
      continue;

    // `op` may have multiple blocks. Make sure that we insert in the right one.
    SmallVector<Block *> blocks;
    for (Region &r : op->getRegions())
      for (Block &b : r.getBlocks())
        blocks.push_back(&b);
    auto *insertionBlock = llvm::find_if(
        blocks, [&](Block *b) { return b->findAncestorOpInBlock(*allocaOp); });
    assert(insertionBlock != blocks.end() && "owning block not found");

    // Move to the beginning of the block.
    allocaOp->moveBefore(&(*insertionBlock)->front());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Bufferization-specific helpers for memref types and function arguments.
//===----------------------------------------------------------------------===//

bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  if (!bbArg)
    return false;
  return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}

MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType,
                                                  Attribute memorySpace) {
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
                         layout, memorySpace);
}

BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout. If fully dynamic layout
  // maps are not requested, generate a type with `layout`, which is empty (no
  // layout map) by default.
  auto rankedTensorType = tensorType.cast<RankedTensorType>();
  if (layout || !options.fullyDynamicLayoutMaps) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Ranked memref type with unspecified layout. Choose the most dynamic
  // one.
  // TODO: address space decisions to connect with the actual alloc.
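  // For example (illustrative), tensor<4x?xf32> is converted to
  //   memref<4x?xf32,
  //          affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
  // where both the offset and the strides are dynamic.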
  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamicStrideOrOffset);
  AffineMap stridedLayout = makeStridedLinearLayoutMap(
      dynamicStrides, dynamicOffset, rankedTensorType.getContext());
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}