//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/IR/Value.h"
20 #include "llvm/Support/Debug.h"
21 
22 //===----------------------------------------------------------------------===//
23 // BufferizableOpInterface
24 //===----------------------------------------------------------------------===//
25 
26 namespace mlir {
27 namespace bufferization {
28 
29 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
30 
31 } // namespace bufferization
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "bufferizable-op-interface"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
37 
38 using namespace mlir;
39 using namespace bufferization;
40 
41 /// Attribute name used to mark region arguments that can be bufferized
42 /// in-place during linalg comprehensive bufferization.
43 constexpr const ::llvm::StringLiteral
44     bufferization::BufferizableOpInterface::kInplaceableAttrName;
45 
/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are
/// supported at the moment. If `copy` is set, the shaped value is copied.
/// Otherwise, a tensor with undefined contents is allocated. If `escape` is
/// set, the buffer of the new AllocTensorOp is allowed to escape its block;
/// i.e., no deallocation is inserted for it.
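/// For example (illustrative IR; assuming a tensor<?xf32> value %t and `copy`
/// set), the created op looks roughly like:
///   %0 = bufferization.alloc_tensor() copy(%t) {escape = true}
///       : tensor<?xf32>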
49 static Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
50                                           Value shapedValue, bool escape,
51                                           bool copy = true) {
52   auto tensorType = shapedValue.getType().dyn_cast<RankedTensorType>();
53   assert(tensorType && "only RankedTensorType supported at the moment");
54   Value alloc;
55   if (!copy) {
56     // No copy needed: Just allocate.
57     SmallVector<Value> dynamicSizes;
58     for (int64_t i = 0; i < tensorType.getRank(); ++i)
59       if (tensorType.isDynamicDim(i))
60         dynamicSizes.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
61     alloc = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
62                                     /*copy=*/Value(), escape);
63   } else {
64     // Allocate and copy.
65     alloc = b.create<AllocTensorOp>(loc, tensorType,
66                                     /*dynamicSizes=*/ValueRange(), shapedValue,
67                                     escape);
68   }
69   return alloc;
70 }
71 
72 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
73     RewriterBase &rewriter, const AnalysisState &state) {
74   OpBuilder::InsertionGuard g(rewriter);
75   Operation *op = getOperation();
76   SmallVector<OpOperand *> outOfPlaceOpOperands;
77   DenseSet<OpOperand *> copiedOpOperands;
78   SmallVector<OpResult> outOfPlaceOpResults;
79   DenseSet<OpResult> copiedOpResults;
80 
81   // Find all out-of-place OpOperands.
82   for (OpOperand &opOperand : op->getOpOperands()) {
83     Type operandType = opOperand.get().getType();
84     if (!operandType.isa<TensorType>())
85       continue;
86     if (state.isInPlace(opOperand))
87       continue;
88     if (operandType.isa<UnrankedTensorType>())
89       return op->emitError("copies of unranked tensors are not supported");
90 
91     SmallVector<OpResult> aliasingOpResults =
92         state.getAliasingOpResult(opOperand);
93     if (aliasingOpResults.size() == 1 &&
94         !state.bufferizesToMemoryWrite(opOperand) &&
95         state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
96       // The op itself does not write but may create exactly one alias. Instead
97       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
98       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
99       // where the result is usually a smaller part of the source).
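      // Example (illustrative): for
      //   %0 = tensor.extract_slice %t[0] [4] [1]
      //       : tensor<8xf32> to tensor<4xf32>
      // copying the result %0 moves 4 elements, whereas copying the source %t
      // would move all 8.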
100       outOfPlaceOpResults.push_back(aliasingOpResults.front());
101       if (!state.canOmitTensorCopy(opOperand))
102         copiedOpResults.insert(aliasingOpResults.front());
103     } else {
104       // In all other cases, make a copy of the OpOperand.
105       outOfPlaceOpOperands.push_back(&opOperand);
106       if (!state.canOmitTensorCopy(opOperand))
107         copiedOpOperands.insert(&opOperand);
108     }
109   }
110 
111   // Insert copies of OpOperands.
112   rewriter.setInsertionPoint(op);
113   for (OpOperand *opOperand : outOfPlaceOpOperands) {
114     SmallVector<OpResult> aliasingOpResults =
115         state.getAliasingOpResult(*opOperand);
116     bool escape = llvm::any_of(
117         aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
118     Value copy = allocateTensorForShapedValue(
119         rewriter, op->getLoc(), opOperand->get(), escape,
120         copiedOpOperands.contains(opOperand));
121     rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
122   }
123 
124   // Insert copies of OpResults.
125   rewriter.setInsertionPointAfter(op);
126   for (OpResult opResult : outOfPlaceOpResults) {
127     bool escape = state.isTensorYielded(opResult);
128     Value copy =
129         allocateTensorForShapedValue(rewriter, op->getLoc(), opResult, escape,
130                                      copiedOpResults.count(opResult));
131     SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
132         opResult.getUses(), [](OpOperand &use) { return &use; }));
133     for (OpOperand *use : uses) {
134       // Do not update the alloc_tensor op that we just created.
135       if (use->getOwner() != copy.getDefiningOp())
136         rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
137     }
138   }
139 
140   return success();
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // OpFilter
145 //===----------------------------------------------------------------------===//
146 
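/// Return whether `op` is allowed by this filter. E.g. (illustrative): with an
/// ALLOW rule for the tensor dialect and a DENY rule for tensor.insert, every
/// tensor op except tensor.insert is allowed.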
147 bool OpFilter::isOpAllowed(Operation *op) const {
  // An op is allowed if it matches at least one ALLOW rule (or if there are no
  // ALLOW rules at all) and does not match any DENY rule.
149   bool isAllowed = !hasAllowRule();
150   for (const Entry &entry : entries) {
151     bool filterResult = entry.fn(op);
152     switch (entry.type) {
153     case Entry::ALLOW:
154       isAllowed |= filterResult;
155       break;
156     case Entry::DENY:
157       if (filterResult)
        // DENY filter matches. This op is not allowed, even if other ALLOW
        // filters match.
        return false;
      break;
    }
162   }
163   return isAllowed;
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // BufferizationOptions
168 //===----------------------------------------------------------------------===//
169 
170 // Default constructor for BufferizationOptions.
171 BufferizationOptions::BufferizationOptions() = default;
172 
173 bool BufferizationOptions::isOpAllowed(Operation *op) const {
174   // Special case: If function boundary bufferization is deactivated, do not
175   // allow ops that belong to the `func` dialect.
176   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
177   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
178     return false;
179 
180   return opFilter.isOpAllowed(op);
181 }
182 
183 BufferizableOpInterface
184 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
185   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
186   if (!bufferizableOp)
187     return nullptr;
188   if (!isOpAllowed(op))
189     return nullptr;
190   return bufferizableOp;
191 }
192 
193 BufferizableOpInterface
194 BufferizationOptions::dynCastBufferizableOp(Value value) const {
195   if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
196     if (isOpAllowed(bufferizableOp.getOperation()))
197       return bufferizableOp;
198   return nullptr;
199 }
200 
201 void BufferizationOptions::addDialectStateInitializer(
202     StringRef name, const DialectStateInitFn &fn) {
203   stateInitializers.push_back(
204       [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Helper functions for BufferizableOpInterface
209 //===----------------------------------------------------------------------===//
210 
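/// Move the insertion point of `b` right after the definition of `value`: past
/// the defining op for an OpResult, or to the start of the owning block for a
/// BlockArgument.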
211 static void setInsertionPointAfter(OpBuilder &b, Value value) {
212   if (auto bbArg = value.dyn_cast<BlockArgument>()) {
213     b.setInsertionPointToStart(bbArg.getOwner());
214   } else {
215     b.setInsertionPointAfter(value.getDefiningOp());
216   }
217 }
218 
/// Determine which OpOperands alias with `result` if the op is bufferized in
/// place. Return an empty vector if the op is not bufferizable.
221 SmallVector<OpOperand *>
222 AnalysisState::getAliasingOpOperand(OpResult result) const {
223   if (Operation *op = result.getDefiningOp())
224     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
225       return bufferizableOp.getAliasingOpOperand(result, *this);
226   return {};
227 }
228 
/// Determine which OpResults alias with `opOperand` if the op is bufferized in
/// place. Return an empty vector if the op is not bufferizable.
231 SmallVector<OpResult>
232 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
233   if (auto bufferizableOp =
234           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
235     return bufferizableOp.getAliasingOpResult(opOperand, *this);
236   return {};
237 }
238 
/// Return true if `opOperand` bufferizes to a memory read. Also return `true`
/// if the op is not bufferizable.
241 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
242   if (auto bufferizableOp =
243           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
244     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
245 
246   // Unknown op that returns a tensor. The inplace analysis does not support it.
247   // Conservatively return true.
248   return true;
249 }
250 
/// Return true if `opOperand` bufferizes to a memory write. Also return `true`
/// if the op is not bufferizable.
253 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
254   if (auto bufferizableOp =
255           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
256     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
257 
258   // Unknown op that returns a tensor. The inplace analysis does not support it.
259   // Conservatively return true.
260   return true;
261 }
262 
/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
265 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
266   if (auto bufferizableOp =
267           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
268     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
269 
270   // Unknown op that returns a tensor. The inplace analysis does not support it.
271   // Conservatively return false.
272   return false;
273 }
274 
275 /// Return true if the given value is read by an op that bufferizes to a memory
276 /// read. Also takes into account ops that create an alias but do not read by
277 /// themselves (e.g., ExtractSliceOp).
278 bool AnalysisState::isValueRead(Value value) const {
279   assert(value.getType().isa<TensorType>() && "expected TensorType");
280   SmallVector<OpOperand *> workingSet;
281   for (OpOperand &use : value.getUses())
282     workingSet.push_back(&use);
283 
284   while (!workingSet.empty()) {
285     OpOperand *uMaybeReading = workingSet.pop_back_val();
286     // Skip over all ops that neither read nor write (but create an alias).
287     if (bufferizesToAliasOnly(*uMaybeReading))
288       for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
289         for (OpOperand &use : opResult.getUses())
290           workingSet.push_back(&use);
291     if (bufferizesToMemoryRead(*uMaybeReading))
292       return true;
293   }
294 
295   return false;
296 }
297 
298 // Starting from `value`, follow the use-def chain in reverse, always selecting
299 // the aliasing OpOperands. Find and return Values for which `condition`
300 // evaluates to true. OpOperands of such matching Values are not traversed any
301 // further.
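//
// Example (illustrative): in
//   %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
//   %1 = tensor.extract_slice %0[0] [4] [1] : tensor<?xf32> to tensor<4xf32>
// a traversal starting at %1 with a condition that matches linalg.fill results
// stops at (and returns) %0.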
302 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
303     Value value, llvm::function_ref<bool(Value)> condition) const {
304   llvm::SetVector<Value> result, workingSet;
305   workingSet.insert(value);
306 
307   while (!workingSet.empty()) {
308     Value value = workingSet.pop_back_val();
309     if (condition(value) || value.isa<BlockArgument>()) {
310       result.insert(value);
311       continue;
312     }
313 
314     OpResult opResult = value.cast<OpResult>();
315     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
316     if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
317       result.insert(value);
318       continue;
319     }
320 
321     for (OpOperand *o : opOperands)
322       workingSet.insert(o->get());
323   }
324 
325   return result;
326 }
327 
328 // Find the Values of the last preceding write of a given Value.
329 llvm::SetVector<Value>
330 AnalysisState::findLastPrecedingWrite(Value value) const {
331   return findValueInReverseUseDefChain(value, [&](Value value) {
332     Operation *op = value.getDefiningOp();
333     if (!op)
334       return true;
335     auto bufferizableOp = options.dynCastBufferizableOp(op);
336     if (!bufferizableOp)
337       return true;
338     return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
339   });
340 }
341 
342 AnalysisState::AnalysisState(const BufferizationOptions &options)
343     : options(options) {
344   for (const BufferizationOptions::AnalysisStateInitFn &fn :
345        options.stateInitializers)
346     fn(*this);
347 }
348 
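/// Return true if a copy of `opOperand`'s tensor can be skipped even if the
/// OpOperand bufferizes out-of-place. E.g. (illustrative): the "outs" operand
/// of a linalg.fill is overwritten in full without being read, so its previous
/// contents never need to be copied.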
349 bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
350   // Do not copy if the tensor has undefined contents.
351   if (hasUndefinedContents(&opOperand))
352     return true;
353 
354   // Do not copy if the buffer of the tensor is entirely overwritten (with
355   // values that do not depend on the old tensor).
356   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
357     return true;
358 
359   // Do not copy if the tensor is never read.
360   SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
361   if (!bufferizesToMemoryRead(opOperand) &&
362       llvm::none_of(aliasingOpResults,
363                     [&](OpResult opResult) { return isValueRead(opResult); }))
364     return true;
365 
366   // Default: Cannot omit the copy.
367   return false;
368 }
369 
370 bool AnalysisState::isInPlace(OpOperand &opOperand) const {
371   // In the absence of analysis information, OpOperands that bufferize to a
372   // memory write are out-of-place, i.e., an alloc and copy is inserted.
373   return !bufferizesToMemoryWrite(opOperand);
374 }
375 
376 bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
377   // In the absence of analysis information, we do not know if the values are
378   // equivalent. The conservative answer is "false".
379   return false;
380 }
381 
382 bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
383   // In the absence of analysis information, we do not know if the values may be
384   // aliasing. The conservative answer is "true".
  return true;
386 }
387 
388 bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
389   // In the absence of analysis information, the conservative answer is "false".
390   return false;
391 }
392 
393 bool AnalysisState::isTensorYielded(Value tensor) const {
394   // In the absence of analysis information, the conservative answer is "true".
395   return true;
396 }
397 
398 // bufferization.to_memref is not allowed to change the rank.
399 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
400 #ifndef NDEBUG
401   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
402   assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
403                                    rankedTensorType.getRank()) &&
404          "to_memref would be invalid: mismatching ranks");
405 #endif
406 }
407 
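/// Return the buffer (memref) holding the contents of `tensor`: the memref
/// operand if `tensor` is the result of a bufferization.to_tensor op,
/// otherwise the result of a newly inserted bufferization.to_memref op.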
408 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
409                                         const BufferizationOptions &options) {
410   auto tensorType = tensor.getType().dyn_cast<TensorType>();
411   assert(tensorType && "unexpected non-tensor type");
412 
413   // Replace "%t = to_tensor %m" with %m.
414   if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
415     return toTensorOp.memref();
416 
417   // Insert to_memref op.
418   OpBuilder::InsertionGuard g(rewriter);
419   setInsertionPointAfter(rewriter, tensor);
420   Type memrefType = getMemRefType(tensorType, options);
421   ensureToMemrefOpIsValid(tensor, memrefType);
422   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
423                                                     tensor);
424 }
425 
426 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
427 /// a new buffer and copy over data from the existing buffer if out-of-place
428 /// bufferization was decided.
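/// E.g. (illustrative IR; assuming an out-of-place tensor operand %t of type
/// tensor<?xf32>), the emitted IR looks roughly like:
///   %m = bufferization.to_memref %t : memref<?xf32>
///   %alloc = memref.alloc(%dim) : memref<?xf32>
///   memref.copy %m, %alloc : memref<?xf32> to memref<?xf32>
/// with %alloc returned as the buffer.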
429 FailureOr<Value>
430 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
431                               Optional<ForceInPlacability> overrideInPlace,
432                               Optional<Operation *> customCopyInsertionPoint) {
433   const BufferizationOptions &options = analysisState.getOptions();
434   OpBuilder::InsertionGuard guard(rewriter);
435   Operation *op = opOperand.getOwner();
436   Location loc = op->getLoc();
437   SmallVector<OpResult> aliasingOpResults =
438       analysisState.getAliasingOpResult(opOperand);
439   Value operand = opOperand.get();
440   Value operandBuffer = lookupBuffer(rewriter, operand, options);
441 
442   // Can `operandBuffer` be used directly or do we need a copy?
443   bool inplace =
444       overrideInPlace != FORCE_OUT_OF_PLACE &&
445       (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
446   if (inplace)
447     return operandBuffer;
448 
449   // Bufferizing out-of-place: Allocate a new buffer.
450   // Move insertion point right after `operandBuffer`. That is where the
451   // allocation should be inserted (in the absence of allocation hoisting).
452   setInsertionPointAfter(rewriter, operandBuffer);
453   // Allocate the result buffer. The buffer should be deallocated if the tensor
454   // is not yielded and deallocs are enabled in general.
455   bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
456     return getAnalysisState().isTensorYielded(v);
457   });
458   FailureOr<Value> resultBuffer = createAlloc(
459       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
460   if (failed(resultBuffer))
461     return failure();
462   // Do not copy the buffer if its contents are undefined.
463   if (analysisState.hasUndefinedContents(&opOperand))
464     return resultBuffer;
465   // Do not copy if the copied data is never read.
466   if (!aliasingOpResults.empty() &&
467       !analysisState.bufferizesToMemoryRead(opOperand) &&
468       llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
469         return analysisState.isValueRead(opResult);
470       }))
471     return resultBuffer;
472   // Do not copy if this op does not read the data, but writes it.
473   if (analysisState.bufferizesToMemoryWrite(opOperand) &&
474       !analysisState.bufferizesToMemoryRead(opOperand))
475     return resultBuffer;
476 
477   if (customCopyInsertionPoint) {
478     rewriter.setInsertionPoint(*customCopyInsertionPoint);
479   } else {
480     // The copy happens right before the op that is bufferized.
481     rewriter.setInsertionPoint(op);
482   }
483   if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
484     return failure();
485 
486   return resultBuffer;
487 }
488 
489 /// Return the buffer type for a given Value (tensor) after bufferization.
490 BaseMemRefType BufferizationState::getBufferType(Value value) const {
491   auto tensorType = value.getType().dyn_cast<TensorType>();
492   assert(tensorType && "unexpected non-tensor type");
493 
494   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
495     return toTensorOp.memref().getType().cast<BaseMemRefType>();
496 
497   return getMemRefType(tensorType, getOptions());
498 }
499 
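/// Replace `op` with `values`, wrapping each tensor result in a
/// bufferization.to_tensor op so that existing users continue to see a tensor.
/// E.g. (illustrative): if a tensor-producing op is replaced with a memref %m,
/// all uses of its result are rewritten to use "bufferization.to_tensor %m".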
500 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
501                                                   Operation *op,
502                                                   ValueRange values) {
503   assert(values.size() == op->getNumResults() &&
504          "expected one value per OpResult");
505   OpBuilder::InsertionGuard g(rewriter);
506 
507   // Replace all OpResults with the given values.
508   SmallVector<Value> replacements;
509   for (OpResult opResult : op->getOpResults()) {
510     Value replacement = values[opResult.getResultNumber()];
511     if (opResult.getType().isa<TensorType>()) {
512       // The OpResult is a tensor. Such values are replaced with memrefs during
513       // bufferization.
514       assert((replacement.getType().isa<MemRefType>() ||
515               replacement.getType().isa<UnrankedMemRefType>()) &&
516              "tensor op result should be replaced with a memref value");
517       // The existing uses of the OpResult still expect a tensor. Insert a
518       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
520       rewriter.setInsertionPointAfter(op);
521       replacement = rewriter.create<bufferization::ToTensorOp>(
522           replacement.getLoc(), replacement);
523     }
524     replacements.push_back(replacement);
525   }
526 
527   rewriter.replaceOp(op, replacements);
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // Bufferization-specific scoped alloc/dealloc insertion support.
532 //===----------------------------------------------------------------------===//
533 
534 /// Create a memref allocation with the given type and dynamic extents.
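/// A minimal sketch of a custom allocation callback (`options` is a
/// hypothetical BufferizationOptions instance; stack allocation is assumed to
/// be desired):
///   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
///                             ValueRange dynShape,
///                             unsigned alignment) -> FailureOr<Value> {
///     // Alignment is ignored in this sketch.
///     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
///   };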
535 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
536                                                    MemRefType type,
537                                                    ValueRange dynShape) const {
538   if (allocationFn)
539     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
540 
  // Default buffer allocation via AllocOp.
542   Value allocated = b.create<memref::AllocOp>(
543       loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
544   return allocated;
545 }
546 
547 /// Creates a memref deallocation. The given memref buffer must have been
548 /// allocated using `createAlloc`.
549 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
550                                                   Value allocatedBuffer) const {
551   if (deallocationFn)
552     return (*deallocationFn)(b, loc, allocatedBuffer);
553 
554   // Default buffer deallocation via DeallocOp.
555   b.create<memref::DeallocOp>(loc, allocatedBuffer);
556   return success();
557 }
558 
559 static MemRefType getContiguousMemRefType(ShapedType shapedType,
560                                           Attribute memorySpace = {}) {
561   MemRefLayoutAttrInterface layout = {};
562   return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
563                          layout, memorySpace);
564 }
565 
/// Compute the type of the memref to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions of the returned memref type.
569 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
570                                             Value shapedValue,
571                                             SmallVectorImpl<Value> &dynShape) {
572   MemRefType allocMemRefType =
573       getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
574 
575   // Compute the dynamic part of the shape.
576   bool reifiedShapes = false;
577   if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
578           shapedValue.getDefiningOp())) {
579     ReifiedRankedShapedTypeDims resultDims;
580     if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
581       reifiedShapes = true;
582       OpResult resultValue = shapedValue.dyn_cast<OpResult>();
583       auto &shape = resultDims[resultValue.getResultNumber()];
584       for (const auto &dim : enumerate(allocMemRefType.getShape()))
585         if (ShapedType::isDynamic(dim.value()))
586           dynShape.push_back(shape[dim.index()]);
587     }
588   }
589 
590   if (!reifiedShapes) {
591     for (const auto &dim : enumerate(allocMemRefType.getShape()))
592       if (ShapedType::isDynamic(dim.value())) {
593         assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
594                 shapedValue.getType().isa<MemRefType>()) &&
595                "expected MemRef type");
596         dynShape.push_back(
597             b.create<memref::DimOp>(loc, shapedValue, dim.index()));
598       }
599   }
600 
601   return allocMemRefType;
602 }
603 
/// Allocate a buffer for `shapedValue` at the builder's current insertion
/// point. If `dealloc` is unset, whether to insert a deallocation is derived
/// from the bufferization options and from whether the tensor is yielded.
606 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
607                                                  Value shapedValue,
608                                                  Optional<bool> dealloc) {
609   // Take a guard before anything else.
610   OpBuilder::InsertionGuard g(b);
611 
612   // Compute allocation memref type.
613   assert(shapedValue.getType().isa<ShapedType>());
614   SmallVector<Value> dynShape;
615   MemRefType allocMemRefType =
616       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
617 
618   // Create the buffer allocation.
619   FailureOr<Value> buffer =
620       getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
621   if (failed(buffer))
622     return failure();
623 
  // Should the buffer be deallocated again or should we let it leak?
625   if (dealloc) {
626     if (!dealloc.getValue())
627       return *buffer;
628   } else {
629     assert(shapedValue.getType().isa<TensorType>() &&
630            "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are deactivated in
    // general or if the tensor is yielded from a block.
633     if (!getOptions().createDeallocs ||
634         getAnalysisState().isTensorYielded(shapedValue))
635       return *buffer;
636   }
637 
638   // Create buffer deallocation.
639   b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
640   if (failed(getOptions().createDealloc(b, loc, *buffer)))
641     return failure();
642 
643   return *buffer;
644 }
645 
646 /// Create a memory copy between two memref buffers.
647 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
648                                                  Value from, Value to) const {
649   if (memCpyFn)
650     return (*memCpyFn)(b, loc, from, to);
651 
652   b.create<memref::CopyOp>(loc, from, to);
653   return success();
654 }
655 
656 //===----------------------------------------------------------------------===//
// Miscellaneous bufferization helpers.
658 //===----------------------------------------------------------------------===//
659 
660 bool bufferization::isFunctionArgument(Value value) {
661   auto bbArg = value.dyn_cast<BlockArgument>();
662   if (!bbArg)
663     return false;
664   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
665 }
666 
667 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
668                                             const BufferizationOptions &options,
669                                             MemRefLayoutAttrInterface layout,
670                                             Attribute memorySpace) {
671   // Case 1: Unranked memref type.
672   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
673     assert(!layout && "UnrankedTensorType cannot have a layout map");
674     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
675                                    memorySpace);
676   }
677 
678   // Case 2: Ranked memref type with specified layout.
679   auto rankedTensorType = tensorType.cast<RankedTensorType>();
680   if (layout) {
681     return MemRefType::get(rankedTensorType.getShape(),
682                            rankedTensorType.getElementType(), layout,
683                            memorySpace);
684   }
685 
686   // Case 3: Configured with "fully dynamic layout maps".
687   if (options.unknownTypeConversion ==
688       BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
689     return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
690 
691   // Case 4: Configured with "static identity layout maps".
692   if (options.unknownTypeConversion ==
693       BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
694     return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
695 
696   llvm_unreachable("InferLayoutMap is an invalid option");
697 }
698 
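/// Return a MemRef type with a fully dynamic strided layout. If the given
/// tensor type is unranked, return an unranked MemRef type. E.g.
/// (illustrative): tensor<4x8xf32> maps to roughly
///   memref<4x8xf32, affine_map<(d0, d1)[s0, s1, s2]
///       -> (d0 * s1 + s0 + d1 * s2)>>
/// where s0 is the dynamic offset and s1, s2 are the dynamic strides.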
699 BaseMemRefType
700 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
701                                                    Attribute memorySpace) {
702   // Case 1: Unranked memref type.
703   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
704     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
705                                    memorySpace);
706   }
707 
708   // Case 2: Ranked memref type.
709   auto rankedTensorType = tensorType.cast<RankedTensorType>();
710   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
711   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
712                                       ShapedType::kDynamicStrideOrOffset);
713   AffineMap stridedLayout = makeStridedLinearLayoutMap(
714       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
715   return MemRefType::get(rankedTensorType.getShape(),
716                          rankedTensorType.getElementType(), stridedLayout,
717                          memorySpace);
718 }
719 
720 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
721 /// the given tensor type is unranked, return an unranked MemRef type.
722 BaseMemRefType
723 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
724                                                      Attribute memorySpace) {
725   // Case 1: Unranked memref type.
726   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
727     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
728                                    memorySpace);
729   }
730 
731   // Case 2: Ranked memref type.
732   auto rankedTensorType = tensorType.cast<RankedTensorType>();
733   MemRefLayoutAttrInterface layout = {};
734   return MemRefType::get(rankedTensorType.getShape(),
735                          rankedTensorType.getElementType(), layout,
736                          memorySpace);
737 }
738