//===- BufferizableOpInterface.cpp - Bufferizable Ops ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
10 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/IR/Value.h"
20 #include "llvm/Support/Debug.h"
21 
22 //===----------------------------------------------------------------------===//
23 // BufferizableOpInterface
24 //===----------------------------------------------------------------------===//
25 
26 namespace mlir {
27 namespace bufferization {
28 
29 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
30 
31 } // namespace bufferization
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "bufferizable-op-interface"
35 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
36 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
37 
38 using namespace mlir;
39 using namespace bufferization;
40 
41 /// Attribute name used to mark region arguments that can be bufferized
42 /// in-place during linalg comprehensive bufferization.
43 constexpr const ::llvm::StringLiteral
44     bufferization::BufferizableOpInterface::kInplaceableAttrName;
45 
/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are
/// supported at the moment. If `copy` is set, the contents of the shaped value
/// are copied into the allocation; otherwise, a tensor with undefined contents
/// is allocated. The `escape` flag is forwarded to the created AllocTensorOp.
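/// Illustrative examples (the exact IR depends on the AllocTensorOp assembly
/// format):
///   %0 = bufferization.alloc_tensor(%d0) : tensor<?x4xf32>      // copy=false
///   %0 = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>  // copy=true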
49 static Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
50                                           Value shapedValue, bool escape,
51                                           bool copy = true) {
52   auto tensorType = shapedValue.getType().dyn_cast<RankedTensorType>();
53   assert(tensorType && "only RankedTensorType supported at the moment");
54   Value alloc;
55   if (!copy) {
56     // No copy needed: Just allocate.
57     SmallVector<Value> dynamicSizes;
58     for (int64_t i = 0; i < tensorType.getRank(); ++i)
59       if (tensorType.isDynamicDim(i))
60         dynamicSizes.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
61     alloc = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
62                                     /*copy=*/Value(), escape);
63   } else {
64     // Allocate and copy.
65     alloc = b.create<AllocTensorOp>(loc, tensorType,
66                                     /*dynamicSizes=*/ValueRange(), shapedValue,
67                                     escape);
68   }
69   return alloc;
70 }
71 
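/// Insert explicit bufferization.alloc_tensor copies for all tensor OpOperands
/// (or, where cheaper, OpResults) that were decided to bufferize out-of-place,
/// so that the remaining tensor OpOperands of this op can be bufferized in
/// place.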
72 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
73     RewriterBase &rewriter, const AnalysisState &state) {
74   OpBuilder::InsertionGuard g(rewriter);
75   Operation *op = getOperation();
76   SmallVector<OpOperand *> outOfPlaceOpOperands;
77   DenseSet<OpOperand *> copiedOpOperands;
78   SmallVector<OpResult> outOfPlaceOpResults;
79   DenseSet<OpResult> copiedOpResults;
80 
81   // Find all out-of-place OpOperands.
82   for (OpOperand &opOperand : op->getOpOperands()) {
83     Type operandType = opOperand.get().getType();
84     if (!operandType.isa<TensorType>())
85       continue;
86     if (state.isInPlace(opOperand))
87       continue;
88     if (operandType.isa<UnrankedTensorType>())
89       return op->emitError("copies of unranked tensors are not supported");
90 
91     SmallVector<OpResult> aliasingOpResults =
92         state.getAliasingOpResult(opOperand);
93     if (aliasingOpResults.size() == 1 &&
94         !state.bufferizesToMemoryWrite(opOperand) &&
95         state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
96       // The op itself does not write but may create exactly one alias. Instead
97       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
98       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
99       // where the result is usually a smaller part of the source).
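      // Illustrative example: for an out-of-place
      //   %0 = tensor.extract_slice %t[0] [5] [1]
      //       : tensor<100xf32> to tensor<5xf32>
      // copying the 5-element result is cheaper than copying the 100-element
      // source %t.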
100       outOfPlaceOpResults.push_back(aliasingOpResults.front());
101       if (!state.canOmitTensorCopy(opOperand))
102         copiedOpResults.insert(aliasingOpResults.front());
103     } else {
104       // In all other cases, make a copy of the OpOperand.
105       outOfPlaceOpOperands.push_back(&opOperand);
106       if (!state.canOmitTensorCopy(opOperand))
107         copiedOpOperands.insert(&opOperand);
108     }
109   }
110 
111   // Insert copies of OpOperands.
112   rewriter.setInsertionPoint(op);
113   for (OpOperand *opOperand : outOfPlaceOpOperands) {
114     SmallVector<OpResult> aliasingOpResults =
115         state.getAliasingOpResult(*opOperand);
116     bool escape = llvm::any_of(
117         aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
118     Value copy = allocateTensorForShapedValue(
119         rewriter, op->getLoc(), opOperand->get(), escape,
120         copiedOpOperands.contains(opOperand));
121     rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
122   }
123 
124   // Insert copies of OpResults.
125   rewriter.setInsertionPointAfter(op);
126   for (OpResult opResult : outOfPlaceOpResults) {
127     bool escape = state.isTensorYielded(opResult);
128     Value copy =
129         allocateTensorForShapedValue(rewriter, op->getLoc(), opResult, escape,
130                                      copiedOpResults.count(opResult));
131     SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
132         opResult.getUses(), [](OpOperand &use) { return &use; }));
133     for (OpOperand *use : uses) {
134       // Do not update the alloc_tensor op that we just created.
135       if (use->getOwner() != copy.getDefiningOp())
136         rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
137     }
138   }
139 
140   return success();
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // OpFilter
145 //===----------------------------------------------------------------------===//
146 
147 bool OpFilter::isOpAllowed(Operation *op) const {
  // If no ALLOW rule was specified, all ops are allowed by default. Otherwise,
  // an op is allowed only if it matches at least one ALLOW rule and no DENY
  // rule.
  bool isAllowed = !hasAllowRule();
150   for (const Entry &entry : entries) {
151     bool filterResult = entry.fn(op);
152     switch (entry.type) {
153     case Entry::ALLOW:
154       isAllowed |= filterResult;
155       break;
    case Entry::DENY:
      if (filterResult)
        // DENY filter matches. This op is not allowed, even if other ALLOW
        // filters match.
        return false;
      break;
    }
162   }
163   return isAllowed;
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // BufferizationOptions
168 //===----------------------------------------------------------------------===//
169 
170 // Default constructor for BufferizationOptions.
171 BufferizationOptions::BufferizationOptions() = default;
172 
173 bool BufferizationOptions::isOpAllowed(Operation *op) const {
174   // Special case: If function boundary bufferization is deactivated, do not
175   // allow ops that belong to the `func` dialect.
176   bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
177   if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
178     return false;
179 
180   return opFilter.isOpAllowed(op);
181 }
182 
183 BufferizableOpInterface
184 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
185   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
186   if (!bufferizableOp)
187     return nullptr;
188   if (!isOpAllowed(op))
189     return nullptr;
190   return bufferizableOp;
191 }
192 
193 BufferizableOpInterface
194 BufferizationOptions::dynCastBufferizableOp(Value value) const {
195   if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
196     if (isOpAllowed(bufferizableOp.getOperation()))
197       return bufferizableOp;
198   return nullptr;
199 }
200 
201 void BufferizationOptions::addDialectStateInitializer(
202     StringRef name, const DialectStateInitFn &fn) {
203   stateInitializers.push_back(
204       [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // Helper functions for BufferizableOpInterface
209 //===----------------------------------------------------------------------===//
210 
211 static void setInsertionPointAfter(OpBuilder &b, Value value) {
212   if (auto bbArg = value.dyn_cast<BlockArgument>()) {
213     b.setInsertionPointToStart(bbArg.getOwner());
214   } else {
215     b.setInsertionPointAfter(value.getDefiningOp());
216   }
217 }
218 
/// Determine which OpOperands will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
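/// E.g., the result of a tensor.insert_slice aliases the op's destination
/// operand (but not its source operand).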
221 SmallVector<OpOperand *>
222 AnalysisState::getAliasingOpOperand(OpResult result) const {
223   if (Operation *op = result.getDefiningOp())
224     if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
225       return bufferizableOp.getAliasingOpOperand(result, *this);
226   return {};
227 }
228 
/// Determine which OpResults will alias with `opOperand` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
231 SmallVector<OpResult>
232 AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
233   if (auto bufferizableOp =
234           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
235     return bufferizableOp.getAliasingOpResult(opOperand, *this);
236   return {};
237 }
238 
239 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
240 /// op is not bufferizable.
241 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
242   if (auto bufferizableOp =
243           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
244     return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
245 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return true.
248   return true;
249 }
250 
251 /// Return true if `opOperand` bufferizes to a memory write. Return
252 /// `true` if the op is not bufferizable.
253 bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
254   if (auto bufferizableOp =
255           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
256     return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
257 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return true.
260   return true;
261 }
262 
/// Return true if `opOperand` neither reads nor writes but bufferizes to an
/// alias. Return false if the op is not bufferizable.
265 bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
266   if (auto bufferizableOp =
267           getOptions().dynCastBufferizableOp(opOperand.getOwner()))
268     return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
269 
  // Unknown op that operates on tensors. The inplace analysis does not support
  // it. Conservatively return false.
272   return false;
273 }
274 
275 /// Return true if the given value is read by an op that bufferizes to a memory
276 /// read. Also takes into account ops that create an alias but do not read by
277 /// themselves (e.g., ExtractSliceOp).
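/// E.g. (illustrative), `%t` is considered read here because the aliasing
/// slice `%0` is read by the tensor.extract:
///   %0 = tensor.extract_slice %t[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
///   %1 = tensor.extract %0[%idx] : tensor<5xf32>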
278 bool AnalysisState::isValueRead(Value value) const {
279   assert(value.getType().isa<TensorType>() && "expected TensorType");
280   SmallVector<OpOperand *> workingSet;
281   for (OpOperand &use : value.getUses())
282     workingSet.push_back(&use);
283 
284   while (!workingSet.empty()) {
285     OpOperand *uMaybeReading = workingSet.pop_back_val();
286     // Skip over all ops that neither read nor write (but create an alias).
287     if (bufferizesToAliasOnly(*uMaybeReading))
288       for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
289         for (OpOperand &use : opResult.getUses())
290           workingSet.push_back(&use);
291     if (bufferizesToMemoryRead(*uMaybeReading))
292       return true;
293   }
294 
295   return false;
296 }
297 
298 // Starting from `value`, follow the use-def chain in reverse, always selecting
299 // the aliasing OpOperands. Find and return Values for which `condition`
300 // evaluates to true. OpOperands of such matching Values are not traversed any
301 // further.
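//
// Example (illustrative): starting from %2 in
//   %1 = tensor.insert_slice %a into %t ...
//   %2 = tensor.insert_slice %b into %1 ...
// the traversal follows the aliasing destination operands %1 and then %t; if
// `condition` matches neither value, the returned set is {%t} (e.g., when %t
// is a block argument).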
302 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
303     Value value, llvm::function_ref<bool(Value)> condition) const {
304   llvm::SetVector<Value> result, workingSet;
305   workingSet.insert(value);
306 
307   while (!workingSet.empty()) {
308     Value value = workingSet.pop_back_val();
309     if (condition(value) || value.isa<BlockArgument>()) {
310       result.insert(value);
311       continue;
312     }
313 
314     OpResult opResult = value.cast<OpResult>();
315     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
316     if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
317       result.insert(value);
318       continue;
319     }
320 
321     for (OpOperand *o : opOperands)
322       workingSet.insert(o->get());
323   }
324 
325   return result;
326 }
327 
328 // Find the Values of the last preceding write of a given Value.
329 llvm::SetVector<Value>
330 AnalysisState::findLastPrecedingWrite(Value value) const {
331   return findValueInReverseUseDefChain(value, [&](Value value) {
332     Operation *op = value.getDefiningOp();
333     if (!op)
334       return true;
335     auto bufferizableOp = options.dynCastBufferizableOp(op);
336     if (!bufferizableOp)
337       return true;
338     return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
339   });
340 }
341 
342 AnalysisState::AnalysisState(const BufferizationOptions &options)
343     : options(options) {
344   for (const BufferizationOptions::AnalysisStateInitFn &fn :
345        options.stateInitializers)
346     fn(*this);
347 }
348 
349 bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
350   // Do not copy if the tensor has undefined contents.
351   if (hasUndefinedContents(&opOperand))
352     return true;
353 
354   // Do not copy if the buffer of the tensor is entirely overwritten (with
355   // values that do not depend on the old tensor).
356   if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
357     return true;
358 
359   // Do not copy if the tensor is never read.
360   SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
361   if (!bufferizesToMemoryRead(opOperand) &&
362       llvm::none_of(aliasingOpResults,
363                     [&](OpResult opResult) { return isValueRead(opResult); }))
364     return true;
365 
366   // Default: Cannot omit the copy.
367   return false;
368 }
369 
370 // bufferization.to_memref is not allowed to change the rank.
371 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
372 #ifndef NDEBUG
373   auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
374   assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
375                                    rankedTensorType.getRank()) &&
376          "to_memref would be invalid: mismatching ranks");
377 #endif
378 }
379 
380 Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
381                                         const BufferizationOptions &options) {
382   auto tensorType = tensor.getType().dyn_cast<TensorType>();
383   assert(tensorType && "unexpected non-tensor type");
384 
385   // Replace "%t = to_tensor %m" with %m.
386   if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
387     return toTensorOp.memref();
388 
389   // Insert to_memref op.
390   OpBuilder::InsertionGuard g(rewriter);
391   setInsertionPointAfter(rewriter, tensor);
392   Type memrefType = getMemRefType(tensorType, options);
393   ensureToMemrefOpIsValid(tensor, memrefType);
394   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
395                                                     tensor);
396 }
397 
398 /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
399 /// a new buffer and copy over data from the existing buffer if out-of-place
400 /// bufferization was decided.
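/// In the out-of-place case, the inserted IR is roughly (with the default
/// alloc/copy functions, and unless the copy can be elided):
///   %alloc = memref.alloc(...) : memref<...>
///   memref.copy %operand_buffer, %alloc : memref<...> to memref<...>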
401 FailureOr<Value>
402 BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
403                               Optional<ForceInPlacability> overrideInPlace,
404                               Optional<Operation *> customCopyInsertionPoint) {
405   const BufferizationOptions &options = analysisState.getOptions();
406   OpBuilder::InsertionGuard guard(rewriter);
407   Operation *op = opOperand.getOwner();
408   Location loc = op->getLoc();
409   SmallVector<OpResult> aliasingOpResults =
410       analysisState.getAliasingOpResult(opOperand);
411   Value operand = opOperand.get();
412   Value operandBuffer = lookupBuffer(rewriter, operand, options);
413 
414   // Can `operandBuffer` be used directly or do we need a copy?
415   bool inplace =
416       overrideInPlace != FORCE_OUT_OF_PLACE &&
417       (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
418   if (inplace)
419     return operandBuffer;
420 
421   // Bufferizing out-of-place: Allocate a new buffer.
422   // Move insertion point right after `operandBuffer`. That is where the
423   // allocation should be inserted (in the absence of allocation hoisting).
424   setInsertionPointAfter(rewriter, operandBuffer);
425   // Allocate the result buffer. The buffer should be deallocated if the tensor
426   // is not yielded and deallocs are enabled in general.
427   bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) {
428     return getAnalysisState().isTensorYielded(v);
429   });
430   FailureOr<Value> resultBuffer = createAlloc(
431       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
432   if (failed(resultBuffer))
433     return failure();
434   // Do not copy the buffer if its contents are undefined.
435   if (analysisState.hasUndefinedContents(&opOperand))
436     return resultBuffer;
437   // Do not copy if the copied data is never read.
438   if (!aliasingOpResults.empty() &&
439       !analysisState.bufferizesToMemoryRead(opOperand) &&
440       llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
441         return analysisState.isValueRead(opResult);
442       }))
443     return resultBuffer;
444   // Do not copy if this op does not read the data, but writes it.
445   if (analysisState.bufferizesToMemoryWrite(opOperand) &&
446       !analysisState.bufferizesToMemoryRead(opOperand))
447     return resultBuffer;
448 
449   if (customCopyInsertionPoint) {
450     rewriter.setInsertionPoint(*customCopyInsertionPoint);
451   } else {
452     // The copy happens right before the op that is bufferized.
453     rewriter.setInsertionPoint(op);
454   }
455   if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer)))
456     return failure();
457 
458   return resultBuffer;
459 }
460 
461 /// Return the buffer type for a given Value (tensor) after bufferization.
462 BaseMemRefType BufferizationState::getBufferType(Value value) const {
463   auto tensorType = value.getType().dyn_cast<TensorType>();
464   assert(tensorType && "unexpected non-tensor type");
465 
466   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
467     return toTensorOp.memref().getType().cast<BaseMemRefType>();
468 
469   return getMemRefType(tensorType, getOptions());
470 }
471 
472 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
473                                                   Operation *op,
474                                                   ValueRange values) {
475   assert(values.size() == op->getNumResults() &&
476          "expected one value per OpResult");
477   OpBuilder::InsertionGuard g(rewriter);
478 
479   // Replace all OpResults with the given values.
480   SmallVector<Value> replacements;
481   for (OpResult opResult : op->getOpResults()) {
482     Value replacement = values[opResult.getResultNumber()];
483     if (opResult.getType().isa<TensorType>()) {
484       // The OpResult is a tensor. Such values are replaced with memrefs during
485       // bufferization.
486       assert((replacement.getType().isa<MemRefType>() ||
487               replacement.getType().isa<UnrankedMemRefType>()) &&
488              "tensor op result should be replaced with a memref value");
489       // The existing uses of the OpResult still expect a tensor. Insert a
      // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and is eventually removed by DCE.
492       rewriter.setInsertionPointAfter(op);
493       replacement = rewriter.create<bufferization::ToTensorOp>(
494           replacement.getLoc(), replacement);
495     }
496     replacements.push_back(replacement);
497   }
498 
499   rewriter.replaceOp(op, replacements);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // Bufferization-specific scoped alloc/dealloc insertion support.
504 //===----------------------------------------------------------------------===//
505 
506 /// Create a memref allocation with the given type and dynamic extents.
507 FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
508                                                    MemRefType type,
509                                                    ValueRange dynShape) const {
510   if (allocationFn)
511     return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);
512 
  // Default buffer allocation via AllocOp.
514   Value allocated = b.create<memref::AllocOp>(
515       loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
516   return allocated;
517 }
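
// Note: The default allocation can be overridden by setting
// `BufferizationOptions::allocationFn`. Usage sketch (the callback receives
// the same arguments that are forwarded above; e.g., allocating on the stack
// instead of the heap):
//   options.allocationFn = [](OpBuilder &b, Location loc, MemRefType type,
//                             ValueRange dynShape,
//                             unsigned alignment) -> FailureOr<Value> {
//     return b.create<memref::AllocaOp>(loc, type, dynShape).getResult();
//   };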
518 
519 /// Creates a memref deallocation. The given memref buffer must have been
520 /// allocated using `createAlloc`.
521 LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
522                                                   Value allocatedBuffer) const {
523   if (deallocationFn)
524     return (*deallocationFn)(b, loc, allocatedBuffer);
525 
526   // Default buffer deallocation via DeallocOp.
527   b.create<memref::DeallocOp>(loc, allocatedBuffer);
528   return success();
529 }
530 
531 static MemRefType getContiguousMemRefType(ShapedType shapedType,
532                                           Attribute memorySpace = {}) {
533   MemRefLayoutAttrInterface layout = {};
534   return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
535                          layout, memorySpace);
536 }
537 
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also return, by reference in `dynShape`, the values for the
/// dynamic dimensions of the returned `memref` type.
541 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
542                                             Value shapedValue,
543                                             SmallVectorImpl<Value> &dynShape) {
544   MemRefType allocMemRefType =
545       getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
546 
547   // Compute the dynamic part of the shape.
548   bool reifiedShapes = false;
549   if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
550           shapedValue.getDefiningOp())) {
551     ReifiedRankedShapedTypeDims resultDims;
552     if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
553       reifiedShapes = true;
554       OpResult resultValue = shapedValue.dyn_cast<OpResult>();
555       auto &shape = resultDims[resultValue.getResultNumber()];
556       for (const auto &dim : enumerate(allocMemRefType.getShape()))
557         if (ShapedType::isDynamic(dim.value()))
558           dynShape.push_back(shape[dim.index()]);
559     }
560   }
561 
562   if (!reifiedShapes) {
563     for (const auto &dim : enumerate(allocMemRefType.getShape()))
564       if (ShapedType::isDynamic(dim.value())) {
565         assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
566                 shapedValue.getType().isa<MemRefType>()) &&
567                "expected MemRef type");
568         dynShape.push_back(
569             b.create<memref::DimOp>(loc, shapedValue, dim.index()));
570       }
571   }
572 
573   return allocMemRefType;
574 }
575 
576 /// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
577 /// block in case of a bbArg).
578 FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
579                                                  Value shapedValue,
580                                                  Optional<bool> dealloc) {
581   // Take a guard before anything else.
582   OpBuilder::InsertionGuard g(b);
583 
584   // Compute allocation memref type.
585   assert(shapedValue.getType().isa<ShapedType>());
586   SmallVector<Value> dynShape;
587   MemRefType allocMemRefType =
588       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
589 
590   // Create the buffer allocation.
591   FailureOr<Value> buffer =
592       getOptions().createAlloc(b, loc, allocMemRefType, dynShape);
593   if (failed(buffer))
594     return failure();
595 
  // Should the buffer be deallocated again or should we let it leak?
597   if (dealloc) {
598     if (!dealloc.getValue())
599       return *buffer;
600   } else {
601     assert(shapedValue.getType().isa<TensorType>() &&
602            "must specify `dealloc` if non-tensor value is passed");
    // The buffer should not be deallocated if deallocs are generally
    // deactivated or if the tensor is yielded from a block.
605     if (!getOptions().createDeallocs ||
606         getAnalysisState().isTensorYielded(shapedValue))
607       return *buffer;
608   }
609 
610   // Create buffer deallocation.
611   b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
612   if (failed(getOptions().createDealloc(b, loc, *buffer)))
613     return failure();
614 
615   return *buffer;
616 }
617 
618 /// Create a memory copy between two memref buffers.
619 LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
620                                                  Value from, Value to) const {
621   if (memCpyFn)
622     return (*memCpyFn)(b, loc, from, to);
623 
624   b.create<memref::CopyOp>(loc, from, to);
625   return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
// Bufferization-specific type and value helpers.
630 //===----------------------------------------------------------------------===//
631 
632 bool bufferization::isFunctionArgument(Value value) {
633   auto bbArg = value.dyn_cast<BlockArgument>();
634   if (!bbArg)
635     return false;
636   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
637 }
638 
639 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
640                                             const BufferizationOptions &options,
641                                             MemRefLayoutAttrInterface layout,
642                                             Attribute memorySpace) {
643   // Case 1: Unranked memref type.
644   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
645     assert(!layout && "UnrankedTensorType cannot have a layout map");
646     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
647                                    memorySpace);
648   }
649 
650   // Case 2: Ranked memref type with specified layout.
651   auto rankedTensorType = tensorType.cast<RankedTensorType>();
652   if (layout) {
653     return MemRefType::get(rankedTensorType.getShape(),
654                            rankedTensorType.getElementType(), layout,
655                            memorySpace);
656   }
657 
658   // Case 3: Configured with "fully dynamic layout maps".
659   if (options.unknownTypeConversion ==
660       BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
661     return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
662 
663   // Case 4: Configured with "static identity layout maps".
664   if (options.unknownTypeConversion ==
665       BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
666     return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
667 
668   llvm_unreachable("InferLayoutMap is an invalid option");
669 }
670 
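/// Return a MemRef type with a fully dynamic layout map (dynamic strides and
/// dynamic offset). If the given tensor type is unranked, return an unranked
/// MemRef type.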
671 BaseMemRefType
672 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
673                                                    Attribute memorySpace) {
674   // Case 1: Unranked memref type.
675   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
676     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
677                                    memorySpace);
678   }
679 
680   // Case 2: Ranked memref type.
681   auto rankedTensorType = tensorType.cast<RankedTensorType>();
682   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
683   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
684                                       ShapedType::kDynamicStrideOrOffset);
685   AffineMap stridedLayout = makeStridedLinearLayoutMap(
686       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
687   return MemRefType::get(rankedTensorType.getShape(),
688                          rankedTensorType.getElementType(), stridedLayout,
689                          memorySpace);
690 }
691 
692 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
693 /// the given tensor type is unranked, return an unranked MemRef type.
694 BaseMemRefType
695 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
696                                                      Attribute memorySpace) {
697   // Case 1: Unranked memref type.
698   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
699     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
700                                    memorySpace);
701   }
702 
703   // Case 2: Ranked memref type.
704   auto rankedTensorType = tensorType.cast<RankedTensorType>();
705   MemRefLayoutAttrInterface layout = {};
706   return MemRefType::get(rankedTensorType.getShape(),
707                          rankedTensorType.getElementType(), layout,
708                          memorySpace);
709 }
710